Source code for hamgnn.utils.cutoff_functions

import torch
import torch.nn as nn
from math import pi
from e3nn.math import soft_unit_step
from e3nn.util.jit import compile_mode

def cutoff_function(x, cutoff):
    """Smooth cutoff that decays from 1 at ``x = 0`` to 0 at ``x = cutoff``.

    Computes ``exp(-x**2 / ((cutoff - x) * (cutoff + x)))`` wherever
    ``x < cutoff`` and returns 0 everywhere else. The function has infinitely
    many smooth derivatives, all of which tend to 0 as ``x`` approaches
    ``cutoff`` from below.

    Args:
        x (torch.Tensor): input values, e.g. interatomic distances. Presumably
            non-negative: for ``x <= -cutoff`` the denominator is non-positive
            and the result is not a cutoff value.
        cutoff (float): cutoff radius.

    Returns:
        torch.Tensor: cutoff values with the same shape as ``x`` (for a scalar
        ``cutoff``).
    """
    inside = x < cutoff
    # Replace out-of-range inputs by 0 before evaluating the exponential, so the
    # discarded branch is just exp(0) = 1. This keeps inf/NaN out of both the
    # forward value and the gradient that torch.where routes back.
    safe_x = torch.where(inside, x, torch.zeros_like(x))
    denom = (cutoff - safe_x) * (cutoff + safe_x)
    smooth = torch.exp(-safe_x**2 / denom)
    return torch.where(inside, smooth, torch.zeros_like(x))
class cuttoff_envelope(nn.Module):
    """Polynomial envelope that decays smoothly from 1 at ``x = 0`` to 0 at ``x = cutoff``.

    With ``u = x / cutoff`` and ``p = exponent`` the envelope is::

        f(u) = 1 + a * u**p + b * u**(p + 1) + c * u**(p + 2)   for u < 1
        f(u) = 0                                                 for u >= 1

    where ``a = -(p + 1)(p + 2) / 2``, ``b = p(p + 2)`` and
    ``c = -p(p + 1) / 2``. With these coefficients ``f(0) = 1``, and ``f``
    together with its first two derivatives vanishes at ``u = 1``.

    The (misspelled) class name is kept as-is for backward compatibility with
    existing callers.

    Args:
        cutoff (float): cutoff radius; the envelope is 0 for ``x >= cutoff``.
        exponent (int, optional): polynomial order ``p``. Defaults to 6.
    """

    def __init__(self, cutoff, exponent=6):
        super().__init__()
        self.p = exponent
        self.a = -(self.p + 1) * (self.p + 2) / 2
        self.b = self.p * (self.p + 2)
        self.c = -self.p * (self.p + 1) / 2
        self.cutoff = cutoff

    def forward(self, x):
        """Evaluate the envelope.

        Args:
            x (torch.Tensor): input distances (same units as ``cutoff``).

        Returns:
            torch.Tensor: envelope values with the shape and dtype of ``x / cutoff``.
        """
        p, a, b, c = self.p, self.a, self.b, self.c
        # Work in units of the cutoff radius, so the envelope's zero sits at u = 1.
        u = x / self.cutoff
        u_pow_p0 = u.pow(p)
        u_pow_p1 = u_pow_p0 * u
        u_pow_p2 = u_pow_p1 * u
        envelope = 1.0 + a * u_pow_p0 + b * u_pow_p1 + c * u_pow_p2
        # ``u`` is already normalized, so it must be compared with 1 rather than
        # with ``self.cutoff``. Beyond u = 1 the polynomial grows without bound
        # and has to be masked out.
        return envelope * (u < 1.0).to(envelope.dtype)
class CosineCutoff(nn.Module):
    r"""Behler cosine cutoff (adapted from schnetpack).

    .. math::
       f(r) = \begin{cases}
        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
          & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    Args:
        cutoff (float, optional): cutoff radius. Defaults to 5.0.
    """

    def __init__(self, cutoff=5.0):
        super().__init__()
        # Keep the radius as a 1-element float32 buffer: it then moves with the
        # module on ``.to(...)`` and is stored in the state dict under "cutoff".
        radius = torch.tensor([cutoff], dtype=torch.float32, device="cpu")
        self.register_buffer("cutoff", radius)

    def forward(self, distances):
        """Evaluate the cosine cutoff.

        Args:
            distances (torch.Tensor): interatomic distances.

        Returns:
            torch.Tensor: cutoff values; exactly 0 for distances >= cutoff.
        """
        within = (distances < self.cutoff).float()
        phase = distances * pi / self.cutoff
        envelope = 0.5 * (torch.cos(phase) + 1.0)
        # Zero out contributions at or beyond the cutoff radius (in place, so the
        # result keeps the dtype of the cosine term).
        envelope *= within
        return envelope
@compile_mode("script")
class SoftUnitStepCutoff(nn.Module):
    """Learnable soft step that converts edge distances into edge weights.

    The weight is ``soft_unit_step(cut_param * (1 - d / cutoff))``, where
    ``soft_unit_step`` is imported from ``e3nn.math``. Presumably it is 0 for
    non-positive arguments, so weights vanish for ``d >= cutoff``; confirm this
    against the e3nn docs.

    NOTE(review): ``cut_param`` is unconstrained. If training drives it
    negative, the argument flips sign and the cutoff behavior inverts. Verify
    whether that matters in practice.

    Attributes:
        cutoff (float): distance at which the step argument reaches zero.
        cut_param (nn.Parameter): learnable scale (initialised to 10.0, in the
            default dtype) controlling how sharply the weight changes.
    """

    def __init__(self, cutoff):
        """Create the module.

        Args:
            cutoff (float): cutoff distance for the step function.
        """
        super().__init__()
        self.cutoff = cutoff
        init_scale = torch.tensor(10.0, dtype=torch.get_default_dtype())
        self.cut_param = nn.Parameter(init_scale)

    def forward(self, edge_distance):
        """Apply the soft unit step to edge distances.

        Args:
            edge_distance (Tensor): edge distances.

        Returns:
            Tensor: edge weights after applying the cutoff.
        """
        relative_gap = 1.0 - edge_distance / self.cutoff
        return soft_unit_step(self.cut_param * relative_gap)