import torch
import torch.nn as nn
from torch.nn import Linear
from torch.nn.init import xavier_uniform_, constant_
from functools import partial
from torch.nn import Linear, BatchNorm1d, ELU
zeros_initializer = partial(constant_, val=0.0)
[docs]
def linear_bn_act(in_features: int, out_features: int, lbias: bool = True,
activation = ELU(), use_batch_norm: bool = True):
"""
Create a sequential module that includes a linear layer, optional batch normalization, and activation functions
Args:
in_features (int): Number of input features
out_features (int): Number of output features
lbias (bool): Whether a bias is included in the linear layer
activation (callable): The activation function to be used
use_batch_norm (bool): Whether it includes batch normalization or not
Returns:
torch.nn.Sequential: A sequential module containing a linear layer, an optional batch normalization, and an activation function
"""
layers = []
layers.append(Linear(in_features, out_features, bias=lbias))
if use_batch_norm:
layers.append(BatchNorm1d(out_features))
if activation is not None:
layers.append(activation)
return nn.Sequential(*layers)
[docs]
class denseLayer(nn.Module):
def __init__(self, in_features: int=None, out_features: int=None, bias:bool=True,
use_batch_norm:bool=True, activation=nn.ELU()):
super().__init__()
self.lba = linear_bn_act(in_features=in_features, out_features=out_features, lbias=bias,
activation=activation, use_batch_norm=use_batch_norm)
self.linear = Linear(out_features, out_features, bias=bias)
[docs]
def forward(self, x):
out = self.linear(self.lba(x))
return out
[docs]
class Dense(nn.Linear):
r"""From schnetpack
Fully connected linear layer with activation function.
.. math::
y = activation(xW^T + b)
Args:
in_features (int): number of input feature :math:`x`.
out_features (int): number of output features :math:`y`.
bias (bool, optional): if False, the layer will not adapt bias :math:`b`.
activation (callable, optional): if None, no activation function is used.
weight_init (callable, optional): weight initializer from current weight.
bias_init (callable, optional): bias initializer from current bias.
"""
def __init__(
self,
in_features,
out_features,
bias=True,
activation=None,
weight_init=xavier_uniform_,
bias_init=zeros_initializer,
):
self.weight_init = weight_init
self.bias_init = bias_init
super(Dense, self).__init__(in_features, out_features, bias)
self.activation = activation
# initialize linear layer y = xW^T + b
[docs]
def reset_parameters(self):
"""Reinitialize model weight and bias values."""
self.weight_init(self.weight)
if self.bias is not None:
self.bias_init(self.bias)
[docs]
def forward(self, inputs):
"""Compute layer output.
Args:
inputs (dict of torch.Tensor): batch of input values.
Returns:
torch.Tensor: layer output.
"""
# compute linear layer y = xW^T + b
y = super(Dense, self).forward(inputs)
# add activation function
if self.activation:
y = self.activation(y)
return y