Model Components
This section documents the general-purpose components from which the HamGNN v2.1 models are built.
Embeddings
Radial and pair embeddings for edges and nodes in equivariant HamGNN stacks.
Includes radial-basis edge encodings, one-hot and spherical-harmonic style embedders,
and pair-interaction feature blocks used with NequIP-style GraphModuleMixin.
- class hamgnn.nn.embeddings.RadialBasisEdgeEncoding(basis=None, cutoff=None, out_field='edge_embedding', irreps_in=None)[source]
Bases: GraphModuleMixin, Module
Encodes edge lengths using a specified radial basis.
- class hamgnn.nn.embeddings.EdgeScalarEmbedding(irreps_node_attrs, irreps_edge_embed, irreps_edge_scalars)[source]
Bases: Module
A layer to compute edge scalars from node attributes and edge embeddings.
- Parameters:
irreps_node_attrs (Irreps) – Irreps for node attributes.
irreps_edge_embed (Irreps) – Irreps for edge embeddings.
irreps_edge_scalars (Irreps) – Irreps for edge scalars.
- forward(node_attr_src, node_attr_dst, edge_embed)[source]
Forward pass to compute edge scalars.
- Parameters:
node_attr_src (Tensor) – Source node attributes.
node_attr_dst (Tensor) – Destination node attributes.
edge_embed (Tensor) – Edge embeddings.
- Returns:
Computed edge scalars.
- Return type:
Tensor
- class hamgnn.nn.embeddings.LocalEnvironmentEmbedding(irreps_edge_attrs, irreps_edge_embed, irreps_node_attrs, irreps_edge_scalars, irreps_env_sh, radial_MLP=[64, 64], use_kan=False)[source]
Bases: Module
Embeds local environments using node and edge attributes, edge embeddings, and spherical harmonics.
- Parameters:
irreps_edge_attrs (Irreps) – Irreps for edge attributes.
irreps_edge_embed (Irreps) – Irreps for edge embeddings.
irreps_node_attrs (Irreps) – Irreps for node attributes.
irreps_edge_scalars (Irreps) – Irreps for edge scalars.
irreps_env_sh (Irreps) – Irreps for environment spherical harmonics.
radial_MLP (list[int]) – Dimensions for the radial MLP. Default: [64, 64].
use_kan (bool) – Whether to use KAN for the radial MLP. Default: False.
- forward(edge_index, node_attr, edge_attr, edge_embed)[source]
Forward pass to compute local environment embeddings.
- Parameters:
edge_index (Tensor) – Indices of the edges.
node_attr (Tensor) – Node attributes.
edge_attr (Tensor) – Edge attributes.
edge_embed (Tensor) – Edge embeddings.
- Returns:
Local environment embeddings.
- Return type:
Tensor
- class hamgnn.nn.embeddings.PairInteractionEmbeddingBlock(irreps_node_feats, irreps_edge_attrs, irreps_node_attrs, irreps_edge_embed, irreps_edge_feats, use_kan=False, radial_MLP=None, nonlinearity_type='gate', nonlinearity_scalars={'e': 'ssp', 'o': 'tanh'}, nonlinearity_gates={'e': 'ssp', 'o': 'abs'}, lite_mode=False)[source]
Bases: Module
A pair interaction block for updating edge features based on node features and edge attributes.
- Parameters:
irreps_node_feats (o3.Irreps) – Irreducible representations for node features.
irreps_edge_attrs (o3.Irreps) – Irreducible representations for edge attributes.
irreps_node_attrs (o3.Irreps) – Irreducible representations for node attributes.
irreps_edge_embed (o3.Irreps) – Irreducible representations for edge embeddings.
irreps_edge_feats (o3.Irreps) – Irreducible representations for edge features.
use_skip_connections (bool) – Whether to use skip connections.
use_kan (bool) – Whether to use KAN for the radial MLP.
radial_MLP (Optional[List[int]]) – Architecture of the radial MLP.
nonlinearity_type (str) – Type of nonlinearity to use ("gate" or "norm").
nonlinearity_scalars (Dict[int, Callable]) – Nonlinearity for scalar channels.
nonlinearity_gates (Dict[int, Callable]) – Nonlinearity for gate channels.
- class hamgnn.nn.embeddings.Embedding(num_features, Zmax=87)[source]
Bases: Module
- forward(Z)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Electron Configurations
Nuclear charges and electron configurations (scaled between 0 and 1) for all elements up to Z=86. This encourages the network to learn atom embeddings that generalize across the periodic table.
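The scaling idea can be sketched as follows. The descriptor layout (nuclear charge followed by s, p, d and f valence occupations), the four example elements and the normalisation constants are assumptions chosen for illustration; they are not the table HamGNN ships.

```python
# Hypothetical descriptor per element: [Z, n_s, n_p, n_d, n_f].
# The layout and the values below are illustrative assumptions.
RAW = {
    "H": [1, 1, 0, 0, 0],
    "C": [6, 2, 2, 0, 0],
    "Fe": [26, 2, 0, 6, 0],
    "Rn": [86, 2, 6, 10, 14],
}

# Largest value each column can take: Z = 86 and full s, p, d, f subshells.
COLUMN_MAX = [86, 2, 6, 10, 14]


def scale_descriptor(raw):
    """Divide each entry by its column maximum so every feature lies in [0, 1]."""
    return [value / top for value, top in zip(raw, COLUMN_MAX)]


SCALED = {symbol: scale_descriptor(raw) for symbol, raw in RAW.items()}
```

Because every column shares the same range, chemically similar elements end up with nearby feature vectors, which is what lets a learned embedding interpolate between them.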
MLP Utilities
- hamgnn.utils.mlp.linear_bn_act(in_features, out_features, lbias=True, activation=ELU(alpha=1.0), use_batch_norm=True)[source]
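Create a sequential module that includes a linear layer, optional batch normalization, and an activation function.
- Parameters:
in_features (int) – Number of input features of the linear layer.
out_features (int) – Number of output features of the linear layer.
lbias (bool) – Whether the linear layer has a bias term. Default: True.
activation (Module) – Activation function applied last. Default: ELU(alpha=1.0).
use_batch_norm (bool) – Whether to include batch normalization. Default: True.
- Returns:
A sequential module containing a linear layer, an optional batch normalization, and an activation function.
- Return type:
Sequential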
- class hamgnn.utils.mlp.denseLayer(in_features=None, out_features=None, bias=True, use_batch_norm=True, activation=ELU(alpha=1.0))[source]
Bases: Module
- forward(x)[source]
- class hamgnn.utils.mlp.Dense(in_features, out_features, bias=True, activation=None, weight_init=<function xavier_uniform_>, bias_init=functools.partial(<function constant_>, val=0.0))[source]
Bases: Linear
Fully connected linear layer with activation function (from schnetpack).
\(y = \text{activation}(xW^T + b)\)
- Parameters:
in_features (int) – number of input features \(x\).
out_features (int) – number of output features \(y\).
bias (bool, optional) – if False, the layer will not adapt bias \(b\).
activation (callable, optional) – if None, no activation function is used.
weight_init (callable, optional) – weight initializer from current weight.
bias_init (callable, optional) – bias initializer from current bias.
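The formula \(y = \text{activation}(xW^T + b)\) can be traced by hand with a minimal sketch in plain Python. The function name `dense_forward` is hypothetical, and the weight layout `(out_features, in_features)` is assumed to follow `torch.nn.Linear`; the real class operates on tensors and also handles initialization.

```python
import math


def dense_forward(x, weight, bias=None, activation=None):
    """Compute y = activation(x W^T + b) for a single input vector x.

    weight is a list of rows with shape (out_features, in_features).
    """
    y = []
    for j, row in enumerate(weight):
        # Row j of W dotted with x gives component j of x W^T.
        total = sum(w * xi for w, xi in zip(row, x))
        if bias is not None:
            total += bias[j]
        y.append(total)
    if activation is not None:
        y = [activation(v) for v in y]
    return y


# Two input features mapped to two output features.
linear_out = dense_forward([1.0, 2.0], [[1.0, 0.0], [0.5, 0.5]], bias=[0.0, 1.0])
tanh_out = dense_forward([1.0, 2.0], [[1.0, 0.0], [0.5, 0.5]], bias=[0.0, 1.0],
                         activation=math.tanh)
```

With `activation=None` the layer reduces to a plain affine map, matching the parameter description above.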
Activation Functions
- class hamgnn.utils.activation.SoftUnitStepCutoff(cutoff)[source]
Bases: Module
A PyTorch module that applies a soft unit step function with a cutoff.
- cut_param
A learnable parameter influencing the softness of the step.
- Type:
nn.Parameter
- class hamgnn.utils.activation.SSP(beta=1, threshold=20)[source]
Bases: Module
Shifted SoftPlus (SSP). Applies element-wise \(\text{SSP}(x)=\text{Softplus}(x)-\text{Softplus}(0)\).
- Parameters:
beta – the \(\beta\) value for the Softplus formulation. Default: 1
threshold – values above this revert to a linear function. Default: 20
- Shape:
Input: \((N, *)\) where * means any number of additional dimensions
Output: \((N, *)\), same shape as the input
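A scalar sketch of the definition above, written in plain Python for clarity. The helper names are hypothetical; `beta` and `threshold` are assumed to behave as in PyTorch's Softplus, where inputs with `beta * x > threshold` are passed through linearly.

```python
import math


def softplus(x, beta=1.0, threshold=20.0):
    """Softplus(x) = log(1 + exp(beta * x)) / beta, linear above the threshold."""
    if beta * x > threshold:
        return x
    return math.log1p(math.exp(beta * x)) / beta


def ssp(x, beta=1.0, threshold=20.0):
    """Shifted softplus: Softplus(x) - Softplus(0), so that ssp(0) == 0."""
    return softplus(x, beta, threshold) - softplus(0.0, beta, threshold)
```

The shift by Softplus(0) = log(2)/beta makes the activation pass through the origin, which keeps zero inputs at zero.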
- forward(input)[source]
- class hamgnn.utils.activation.SWISH[source]
Bases: Module
- forward(input)[source]
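The class carries no description of its own. Assuming it follows the standard Swish definition, \(\text{Swish}(x) = x \cdot \sigma(x)\), a scalar sketch looks like this (function names are hypothetical):

```python
import math


def sigmoid(x):
    """Logistic sigmoid, written so that exp() never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def swish(x):
    """Standard Swish activation: x * sigmoid(x)."""
    return x * sigmoid(x)
```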
Basis Functions
- class hamgnn.utils.basis_functions.BernsteinRadialBasisFunctions(num_basis_functions, cutoff)[source]
Bases: Module
- forward(r)[source]
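As a rough illustration of what a Bernstein radial basis computes, the sketch below evaluates the Bernstein polynomials \(b_{k,n}(x) = \binom{n}{k} x^k (1-x)^{n-k}\) at \(x = r / r_\text{cutoff}\). This is an assumption about the general technique, not the HamGNN implementation, which may evaluate the polynomials differently or multiply by a cutoff function.

```python
import math


def bernstein_basis(r, num_basis_functions, cutoff):
    """Evaluate num_basis_functions Bernstein polynomials at x = r / cutoff."""
    # Clamp to [0, 1] so distances beyond the cutoff stay well defined.
    x = min(max(r / cutoff, 0.0), 1.0)
    n = num_basis_functions - 1
    return [math.comb(n, k) * x**k * (1.0 - x) ** (n - k) for k in range(n + 1)]


features = bernstein_basis(1.3, num_basis_functions=8, cutoff=5.0)
```

Bernstein polynomials are non-negative and sum to one at every point, so the resulting features form a smooth partition of the interval from 0 to the cutoff.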
- class hamgnn.utils.basis_functions.ExponentialBernsteinRadialBasisFunctions(num_basis_functions, cutoff, ini_alpha=0.5)[source]
Bases: Module
- forward(r)[source]
- class hamgnn.utils.basis_functions.ExponentialGaussianRadialBasisFunctions(num_basis_functions, cutoff, ini_alpha=0.5)[source]
Bases: Module
- forward(r)[source]
- class hamgnn.utils.basis_functions.GaussianRadialBasisFunctions(num_basis_functions, cutoff)[source]
Bases: Module
- forward(r)[source]
- class hamgnn.utils.basis_functions.OverlapBernsteinRadialBasisFunctions(num_basis_functions, cutoff, ini_alpha=0.5)[source]
Bases: Module
- forward(r)[source]
- class hamgnn.utils.basis_functions.sph_harm_layer(num_spherical)[source]
Bases: Module
- forward(angle)[source]
- class hamgnn.utils.basis_functions.BesselBasis(cutoff=5.0, n_rbf=None, cutoff_func=None)[source]
Bases: Module
Sine radial basis expansion with Coulomb decay (0th-order Bessel basis from DimeNet).
- forward(dist)[source]
Computes the 0th-order Bessel expansion of inter-atomic distances.
- Parameters:
dist (torch.Tensor) – inter-atomic distances with (N_edge,) shape.
- Returns:
the 0th order Bessel expansion of inter-atomic distances with (N_edge, n_rbf) shape.
- Return type:
rbf (torch.Tensor)
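The "sine with Coulomb decay" form can be sketched in plain Python as \(\sin(n \pi d / r_\text{cutoff}) / d\) for \(n = 1, \dots, n_\text{rbf}\). The function name is hypothetical, any normalization prefactor or `cutoff_func` used by the real class is omitted, and `n_rbf` is given an explicit default here only for illustration.

```python
import math


def bessel_expansion(dist, cutoff=5.0, n_rbf=8):
    """Map distances of shape (N_edge,) to features of shape (N_edge, n_rbf)."""
    rbf = []
    for d in dist:
        row = []
        for n in range(1, n_rbf + 1):
            freq = n * math.pi / cutoff
            # sin(freq * d) / d tends to freq as d -> 0, so use the limit there.
            row.append(freq if d == 0.0 else math.sin(freq * d) / d)
        rbf.append(row)
    return rbf


rbf = bessel_expansion([1.0, 5.0], cutoff=5.0, n_rbf=4)
```

Every basis function vanishes at the cutoff distance because \(\sin(n\pi) = 0\), which is why this basis pairs naturally with a finite interaction radius.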
- class hamgnn.utils.basis_functions.GaussianSmearing(start=0.0, stop=5.0, num_gaussians=50, cutoff_func=None)[source]
Bases: Module
- forward(dist)[source]
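Gaussian smearing expands each distance on a grid of Gaussians whose centers run from `start` to `stop`. The sketch below assumes evenly spaced centers and a width equal to the grid spacing, a common convention in SchNet-style models; the widths used by the HamGNN class and its handling of `cutoff_func` may differ.

```python
import math


def gaussian_smearing(dist, start=0.0, stop=5.0, num_gaussians=50):
    """Map distances of shape (N_edge,) to features of shape (N_edge, num_gaussians)."""
    step = (stop - start) / (num_gaussians - 1)
    centers = [start + i * step for i in range(num_gaussians)]
    # Assumed width: one grid spacing, giving exp(-0.5 * ((d - mu) / step)^2).
    coeff = -0.5 / step**2
    return [[math.exp(coeff * (d - mu) ** 2) for mu in centers] for d in dist]


rows = gaussian_smearing([0.0, 2.5], start=0.0, stop=5.0, num_gaussians=11)
```

Each row peaks at the center closest to the input distance and decays smoothly on either side.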
Cutoff Functions
- class hamgnn.utils.cutoff_functions.cuttoff_envelope(cutoff, exponent=6)[source]
Bases: Module
- forward(x)[source]
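The class is undocumented, but the `exponent` argument suggests a polynomial envelope of the kind introduced with DimeNet. The sketch below is an assumption along those lines, not the HamGNN code: it implements \(u(x) = 1 - \tfrac{(p+1)(p+2)}{2} x^p + p(p+2) x^{p+1} - \tfrac{p(p+1)}{2} x^{p+2}\) with \(x = r / r_\text{cutoff}\), which decays smoothly from 1 to 0.

```python
def polynomial_envelope(r, cutoff, exponent=6):
    """Smooth polynomial envelope: 1 at r = 0, 0 for r >= cutoff."""
    p = exponent
    x = r / cutoff
    if x >= 1.0:
        return 0.0
    a = -(p + 1) * (p + 2) / 2.0
    b = p * (p + 2)
    c = -p * (p + 1) / 2.0
    return 1.0 + a * x**p + b * x ** (p + 1) + c * x ** (p + 2)


envelope_values = [polynomial_envelope(r, cutoff=5.0) for r in (0.0, 2.5, 4.9, 5.0)]
```

The coefficients are chosen so that the envelope and its first two derivatives vanish at the cutoff, avoiding discontinuities when neighbors enter or leave the interaction radius.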
- class hamgnn.utils.cutoff_functions.CosineCutoff(cutoff=5.0)[source]
Bases: Module
Behler cosine cutoff (from schnetpack).
\(f(r) = \begin{cases} 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] & r < r_\text{cutoff} \\ 0 & r \geqslant r_\text{cutoff} \end{cases}\)
- Parameters:
cutoff (float, optional) – cutoff radius.
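The piecewise formula above translates directly into a scalar function. This is a plain-Python sketch with a hypothetical name; the class itself operates on tensors.

```python
import math


def cosine_cutoff(r, cutoff=5.0):
    """Behler cosine cutoff: 0.5 * (1 + cos(pi * r / cutoff)) inside, 0 outside."""
    if r >= cutoff:
        return 0.0
    return 0.5 * (1.0 + math.cos(math.pi * r / cutoff))


cutoff_values = [cosine_cutoff(r) for r in (0.0, 2.5, 5.0, 7.0)]
```

The function equals 1 at zero distance, 0.5 at half the cutoff radius, and 0 at and beyond the cutoff.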
Regression Layers
- class hamgnn.utils.regression_layers.denseRegression(in_features=None, out_features=None, bias=True, use_batch_norm=True, activation=Softplus(beta=1, threshold=20), n_h=3)[source]
Bases: Module
- forward(x)[source]
- class hamgnn.utils.regression_layers.MLPRegression(num_in_features, num_out_features, num_mlp=3, lbias=False, activation=ELU(alpha=1.0), use_batch_norm=False)[source]
Bases: Module
- forward(x)[source]