import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.special import binom
import math
import sympy as sym
from torch_geometric.nn.models.dimenet_utils import real_sph_harm
from .cutoff_functions import cutoff_function
from .activation import softplus_inverse
"""
computes radial basis functions with Bernstein polynomials
"""
[docs]
class BernsteinRadialBasisFunctions(nn.Module):
def __init__(self, num_basis_functions, cutoff):
super(BernsteinRadialBasisFunctions, self).__init__()
self.num_basis_functions = num_basis_functions
#compute values to initialize buffers
logfactorial = np.zeros((num_basis_functions))
for i in range(2,num_basis_functions):
logfactorial[i] = logfactorial[i-1] + np.log(i)
v = np.arange(0,num_basis_functions)
n = (num_basis_functions-1)-v
logbinomial = logfactorial[-1]-logfactorial[v]-logfactorial[n]
#register buffers and parameters
self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float64))
self.register_buffer('logc', torch.tensor(logbinomial, dtype=torch.float64))
self.register_buffer('n', torch.tensor(n, dtype=torch.float64))
self.register_buffer('v', torch.tensor(v, dtype=torch.float64))
self.reset_parameters()
[docs]
def reset_parameters(self):
pass
[docs]
def forward(self, r):
x = torch.log(r/self.cutoff)
x = self.logc + self.n*x + self.v*torch.log(-torch.expm1(x))
rbf = cutoff_function(r, self.cutoff) * torch.exp(x)
return rbf
"""
computes radial basis functions with exponential Bernstein polynomials
"""
[docs]
class ExponentialBernsteinRadialBasisFunctions(nn.Module):
def __init__(self, num_basis_functions, cutoff, ini_alpha=0.5):
super(ExponentialBernsteinRadialBasisFunctions, self).__init__()
self.num_basis_functions = num_basis_functions
self.ini_alpha = ini_alpha
#compute values to initialize buffers
logfactorial = np.zeros((num_basis_functions))
for i in range(2,num_basis_functions):
logfactorial[i] = logfactorial[i-1] + np.log(i)
v = np.arange(0,num_basis_functions)
n = (num_basis_functions-1)-v
logbinomial = logfactorial[-1]-logfactorial[v]-logfactorial[n]
#register buffers and parameters
self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float64))
self.register_buffer('logc', torch.tensor(logbinomial, dtype=torch.float64))
self.register_buffer('n', torch.tensor(n, dtype=torch.float64))
self.register_buffer('v', torch.tensor(v, dtype=torch.float64))
self.register_parameter('_alpha', nn.Parameter(torch.tensor(1.0, dtype=torch.float64)))
self.reset_parameters()
[docs]
def reset_parameters(self):
nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha))
[docs]
def forward(self, r):
alpha = F.softplus(self._alpha)
x = -alpha*r
x = self.logc + self.n*x + self.v*torch.log(-torch.expm1(x))
rbf = cutoff_function(r, self.cutoff) * torch.exp(x)
return rbf
"""
computes radial basis functions with exponential Gaussians
"""
[docs]
class ExponentialGaussianRadialBasisFunctions(nn.Module):
def __init__(self, num_basis_functions, cutoff, ini_alpha=0.5):
super(ExponentialGaussianRadialBasisFunctions, self).__init__()
self.num_basis_functions = num_basis_functions
self.ini_alpha = ini_alpha
self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float64))
self.register_buffer('center', torch.linspace(1, 0, self.num_basis_functions, dtype=torch.float64))
self.register_buffer('width', torch.tensor(1.0*self.num_basis_functions, dtype=torch.float64))
self.register_parameter('_alpha', nn.Parameter(torch.tensor(1.0, dtype=torch.float64)))
self.reset_parameters()
[docs]
def reset_parameters(self):
nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha))
[docs]
def forward(self, r):
alpha = F.softplus(self._alpha)
rbf = cutoff_function(r, self.cutoff) * torch.exp(-self.width*(torch.exp(-alpha*r)-self.center)**2)
return rbf
"""
computes radial basis functions with exponential Gaussians
"""
[docs]
class GaussianRadialBasisFunctions(nn.Module):
def __init__(self, num_basis_functions, cutoff):
super(GaussianRadialBasisFunctions, self).__init__()
self.num_basis_functions = num_basis_functions
self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float64))
self.register_buffer('center', torch.linspace(0, cutoff, self.num_basis_functions, dtype=torch.float64))
self.register_buffer('width', torch.tensor(self.num_basis_functions/cutoff, dtype=torch.float64))
#for compatibility with other basis functions on tensorboard, doesn't do anything
self.register_parameter('_alpha', nn.Parameter(torch.tensor(1.0, dtype=torch.float64)))
self.reset_parameters()
[docs]
def reset_parameters(self):
pass
[docs]
def forward(self, r):
rbf = cutoff_function(r, self.cutoff) * torch.exp(-self.width*(r-self.center)**2)
return rbf
"""
computes radial basis functions with overlap Bernstein polynomials
"""
[docs]
class OverlapBernsteinRadialBasisFunctions(nn.Module):
def __init__(self, num_basis_functions, cutoff, ini_alpha=0.5):
super(OverlapBernsteinRadialBasisFunctions, self).__init__()
self.num_basis_functions = num_basis_functions
self.ini_alpha = ini_alpha
#compute values to initialize buffers
logfactorial = np.zeros((num_basis_functions))
for i in range(2,num_basis_functions):
logfactorial[i] = logfactorial[i-1] + np.log(i)
v = np.arange(0,num_basis_functions)
n = (num_basis_functions-1)-v
logbinomial = logfactorial[-1]-logfactorial[v]-logfactorial[n]
#register buffers and parameters
self.register_buffer('cutoff', torch.tensor(cutoff, dtype=torch.float64))
self.register_buffer('logc', torch.tensor(logbinomial, dtype=torch.float64))
self.register_buffer('n', torch.tensor(n, dtype=torch.float64))
self.register_buffer('v', torch.tensor(v, dtype=torch.float64))
self.register_parameter('_alpha', nn.Parameter(torch.tensor(1.0, dtype=torch.float64)))
self.reset_parameters()
[docs]
def reset_parameters(self):
nn.init.constant_(self._alpha, softplus_inverse(self.ini_alpha))
[docs]
def forward(self, r):
alpha_r = F.softplus(self._alpha)*r
x = torch.log1p(alpha_r)-alpha_r
x = self.logc + self.n*x + self.v*torch.log(-torch.expm1(x))
rbf = cutoff_function(r, self.cutoff) * torch.exp(x)
return rbf
class sph_harm_layer(nn.Module):
    """Evaluate real spherical harmonics of degree l = 0 .. num_spherical-1.

    For each degree, the first form returned by ``real_sph_harm`` (a sympy
    expression in ``theta``; presumably the m = 0 harmonic — confirm against
    torch_geometric) is compiled into a torch-callable with ``sym.lambdify``,
    mapping sympy's ``sin``/``cos`` to ``torch.sin``/``torch.cos``.

    Args:
        num_spherical: number of degrees to evaluate.
    """

    def __init__(self, num_spherical):
        super(sph_harm_layer, self).__init__()
        self.num_spherical = num_spherical
        forms = real_sph_harm(num_spherical)
        self.sph_funcs = []
        theta = sym.symbols('theta')
        torch_ops = {'sin': torch.sin, 'cos': torch.cos}
        for degree in range(num_spherical):
            if degree != 0:
                self.sph_funcs.append(sym.lambdify([theta], forms[degree][0], torch_ops))
                continue
            # Degree 0: evaluate once at theta = 0, then broadcast that value
            # to the shape of whatever tensor is passed in.
            value = sym.lambdify([theta], forms[degree][0], torch_ops)(0)
            self.sph_funcs.append(lambda x, c=value: torch.zeros_like(x) + c)

    def forward(self, angle):
        """Stack every harmonic evaluated on ``angle.unsqueeze(-1)`` along the last axis.

        The result has shape ``angle.shape + (num_spherical,)``, provided each
        compiled function preserves its input's shape.
        """
        theta = angle.unsqueeze(-1)
        return torch.cat([fn(theta) for fn in self.sph_funcs], dim=-1)
class BesselBasis(nn.Module):
    """Sine radial basis with 1/d decay (0th-order Bessel basis, as in DimeNet).

    For ``k = 1 .. n_rbf`` the k-th feature of a distance ``d`` is
    ``sin(k * pi * d / cutoff) / d``, optionally multiplied by
    ``cutoff_func(d)``. No ``sqrt(2 / cutoff)`` normalisation prefactor is
    applied.
    """

    def __init__(self, cutoff=5.0, n_rbf: int = None, cutoff_func=None):
        """
        Args:
            cutoff: radial cutoff; sets the frequencies ``k * pi / cutoff``.
            n_rbf: number of basis functions. Required; the ``None`` default
                only exists for keyword-argument convenience.
            cutoff_func: optional callable applied to ``dist.unsqueeze(-1)``;
                its result multiplies the expansion.

        Raises:
            TypeError: if ``n_rbf`` is not given.
        """
        super(BesselBasis, self).__init__()
        if n_rbf is None:
            # Without this check torch.arange(1, None + 1) fails with an opaque
            # "unsupported operand" TypeError; keep the same exception type.
            raise TypeError("BesselBasis requires n_rbf (the number of basis functions)")
        # Frequencies k * pi / cutoff for k = 1 .. n_rbf, so that
        # sin(freq * d) vanishes at d = cutoff.
        freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        self.register_buffer("freqs", freqs)
        self.cutoff_func = cutoff_func

    def forward(self, dist):
        r"""Computes the 0th order Bessel expansion of inter-atomic distances.

        Args:
            dist (torch.Tensor):
                inter-atomic distances with (N_edge,) shape
        Returns:
            rbf (torch.Tensor):
                the 0th order Bessel expansion of inter-atomic distances
                with (N_edge, n_rbf) shape. A zero distance yields the finite
                limit ``freqs`` instead of NaN.
        """
        a = self.freqs[None, :]
        ax = dist.unsqueeze(-1) * a
        # sin(a*d)/d == a * sinc(a*d/pi), since torch.sinc(x) = sin(pi*x)/(pi*x).
        # The sinc form equals the original for d != 0 but has the finite
        # limit `a` at d = 0, where sin(0)/0 would produce NaN values/gradients.
        rbf = a * torch.sinc(ax / math.pi)
        if self.cutoff_func is not None:
            rbf = rbf * self.cutoff_func(dist.unsqueeze(-1))
        return rbf
class GaussianSmearing(nn.Module):
    """Expand distances in Gaussians with evenly spaced centres.

    Feature ``k`` of a distance ``d`` is ``exp(coeff * (d - offset_k)**2)``,
    where the ``num_gaussians`` offsets are ``linspace(start, stop)`` and
    ``coeff = -0.5 / spacing**2`` (``spacing`` = gap between neighbouring
    offsets). The result is optionally multiplied by ``cutoff_func(d)``.

    Args:
        start: first Gaussian centre.
        stop: last Gaussian centre.
        num_gaussians: number of centres; must be >= 2 so the spacing is defined.
        cutoff_func: optional callable applied to the raw distances, shaped
            ``(N, 1)``; its result multiplies the expansion.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50, cutoff_func=None):
        super(GaussianSmearing, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        # Gaussian width equals the spacing between neighbouring centres.
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)
        self.cutoff_func = cutoff_func

    def forward(self, dist):
        """Return the ``(N, num_gaussians)`` expansion of ``dist`` (flattened to ``N`` values)."""
        d = dist.view(-1, 1)
        diff = d - self.offset.view(1, -1)
        expansion = torch.exp(self.coeff * torch.pow(diff, 2))
        if self.cutoff_func is not None:
            # The envelope must depend on the actual distance. Previously `dist`
            # had been overwritten with `dist - offset`, so the cutoff was
            # evaluated on the shift from each centre instead.
            expansion = expansion * self.cutoff_func(d)
        return expansion