Model Components

This section documents the general-purpose components from which HamGNN v2.1 models are built.

Embeddings

Radial and pair embeddings for edges and nodes in equivariant HamGNN stacks.

The module provides radial-basis edge encodings, one-hot and spherical-harmonic style embedders, and pair-interaction feature blocks designed to work with NequIP-style GraphModuleMixin modules.

class hamgnn.nn.embeddings.RadialBasisEdgeEncoding(basis=None, cutoff=None, out_field='edge_embedding', irreps_in=None)

Bases: GraphModuleMixin, Module

Encodes edge lengths using a specified radial basis.

out_field

The key for storing the encoded edge features.

Type:

str

forward(data)

Computes the edge encoding and updates the data dictionary.

Parameters:

data (Dict[str, Tensor]) – A dictionary containing graph data.

Return type:

Dict[str, Tensor]

Returns:

Updated graph data with encoded edge features.
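The encoding can be pictured with the framework-free sketch below. It assumes NequIP-style conventions: edge lengths are computed from atom positions and the edge index, each length is expanded in the radial basis, and the result is damped by the cutoff envelope before being stored under out_field. The keys pos, edge_index and edge_length, the function name radial_basis_edge_encoding, and the toy basis and cutoff in the usage lines are illustrative assumptions; the real module works on batched PyTorch tensors with the basis and cutoff modules passed to its constructor.

```python
import math
from typing import Callable, Dict, List


def radial_basis_edge_encoding(
    data: Dict[str, object],
    basis: Callable[[float], List[float]],
    cutoff: Callable[[float], float],
    out_field: str = "edge_embedding",
) -> Dict[str, object]:
    """Sketch of a NequIP-style radial edge encoding on plain Python lists."""
    pos = data["pos"]              # assumed key: per-atom (x, y, z) positions
    src, dst = data["edge_index"]  # assumed key: (source indices, target indices)
    lengths = [math.dist(pos[i], pos[j]) for i, j in zip(src, dst)]
    data["edge_length"] = lengths
    # Expand each length in the radial basis and damp it with the cutoff envelope.
    data[out_field] = [[b * cutoff(r) for b in basis(r)] for r in lengths]
    return data


# Usage: a toy two-function basis and a hard cutoff at r = 5 (both hypothetical).
graph = {"pos": [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)], "edge_index": ([0, 1], [1, 0])}
radial_basis_edge_encoding(
    graph,
    basis=lambda r: [math.exp(-r), math.exp(-2.0 * r)],
    cutoff=lambda r: 1.0 if r < 5.0 else 0.0,
)
print(graph["edge_embedding"])
```

Multiplying by the cutoff, rather than only masking, keeps the embedding a smooth function of the distance when a smooth envelope is used.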

class hamgnn.nn.embeddings.EdgeScalarEmbedding(irreps_node_attrs, irreps_edge_embed, irreps_edge_scalars)

Bases: Module

A layer to compute edge scalars from node attributes and edge embeddings.

Parameters:
  • irreps_node_attrs (Irreps) – Irreps for node attributes.

  • irreps_edge_embed (Irreps) – Irreps for edge embeddings.

  • irreps_edge_scalars (Irreps) – Irreps for edge scalars.

forward(node_attr_src, node_attr_dst, edge_embed)

Forward pass to compute edge scalars.

Parameters:
  • node_attr_src (Tensor) – Source node attributes.

  • node_attr_dst (Tensor) – Destination node attributes.

  • edge_embed (Tensor) – Edge embeddings.

Returns:

Computed edge scalars.

Return type:

Tensor

class hamgnn.nn.embeddings.LocalEnvironmentEmbedding(irreps_edge_attrs, irreps_edge_embed, irreps_node_attrs, irreps_edge_scalars, irreps_env_sh, radial_MLP=[64, 64], use_kan=False)

Bases: Module

Embeds local environments using node and edge attributes, edge embeddings, and spherical harmonics.

Parameters:
  • irreps_edge_attrs (Irreps) – Irreps for edge attributes.

  • irreps_edge_embed (Irreps) – Irreps for edge embeddings.

  • irreps_node_attrs (Irreps) – Irreps for node attributes.

  • irreps_edge_scalars (Irreps) – Irreps for edge scalars.

  • irreps_env_sh (Irreps) – Irreps for environment spherical harmonics.

  • radial_MLP (list[int]) – Hidden-layer dimensions of the radial MLP. Default: [64, 64].

  • use_kan (bool) – Whether to use a KAN (Kolmogorov-Arnold network) for the radial network.

forward(edge_index, node_attr, edge_attr, edge_embed)

Forward pass to compute local environment embeddings.

Parameters:
  • edge_index (Tensor) – Indices of the edges.

  • node_attr (Tensor) – Node attributes.

  • edge_attr (Tensor) – Edge attributes.

  • edge_embed (Tensor) – Edge embeddings.

Returns:

Local environment embeddings.

Return type:

Tensor

class hamgnn.nn.embeddings.PairInteractionEmbeddingBlock(irreps_node_feats, irreps_edge_attrs, irreps_node_attrs, irreps_edge_embed, irreps_edge_feats, use_kan=False, radial_MLP=None, nonlinearity_type='gate', nonlinearity_scalars={'e': 'ssp', 'o': 'tanh'}, nonlinearity_gates={'e': 'ssp', 'o': 'abs'}, lite_mode=False)

Bases: Module

A pair interaction block for updating edge features based on node features and edge attributes.

Parameters:
  • irreps_node_feats (o3.Irreps) – Irreducible representations for node features.

  • irreps_edge_attrs (o3.Irreps) – Irreducible representations for edge attributes.

  • irreps_node_attrs (o3.Irreps) – Irreducible representations for node attributes.

  • irreps_edge_embed (o3.Irreps) – Irreducible representations for edge embeddings.

  • irreps_edge_feats (o3.Irreps) – Irreducible representations for edge features.

  • use_kan (bool) – Whether to use a KAN (Kolmogorov-Arnold network) for the radial MLP.

  • radial_MLP (Optional[List[int]]) – Architecture of the radial MLP.

  • nonlinearity_type (str) – Type of nonlinearity to use ("gate" or "norm").

  • nonlinearity_scalars (Dict[str, str]) – Nonlinearities for scalar channels, keyed by parity ("e" for even, "o" for odd).

  • nonlinearity_gates (Dict[str, str]) – Nonlinearities for gate channels, keyed by parity ("e" for even, "o" for odd).

create_linear(irreps_in, irreps_out=None)

Create a linear layer.

create_tensor_product(irreps_mid, instructions)

Create a TensorProduct layer.

init_weight_generator(input_dim, weight_numel)

Initialize weight generator.

forward(data)

Forward pass of the pair interaction block.

Parameters:

data (Dict[str, torch.Tensor]) – A dictionary containing the graph data.

Returns:

Updated edge features.

Return type:

Tensor

class hamgnn.nn.embeddings.Embedding(num_features, Zmax=87)

Bases: Module

reset_parameters()
forward(Z)

Maps each atomic number in Z to a feature vector of length num_features.

Note

Call the module instance (for example, embedding(Z)) rather than forward() directly: the instance call runs any registered hooks, whereas calling forward() silently ignores them.

Electron Configurations

Nuclear charges and electron configurations, scaled to the range [0, 1], for all elements up to Z = 86. Using these physically meaningful descriptors encourages the network to learn atom embeddings that generalize across the periodic table.

MLP Utilities

hamgnn.utils.mlp.linear_bn_act(in_features, out_features, lbias=True, activation=ELU(alpha=1.0), use_batch_norm=True)

Creates a sequential module consisting of a linear layer, optional batch normalization, and an activation function.

Parameters:
  • in_features (int) – Number of input features.

  • out_features (int) – Number of output features.

  • lbias (bool) – Whether the linear layer includes a bias.

  • activation (callable) – The activation function to use.

  • use_batch_norm (bool) – Whether to include batch normalization.

Returns:

A sequential module containing a linear layer, optional batch normalization, and an activation function.

Return type:

torch.nn.Sequential

class hamgnn.utils.mlp.denseLayer(in_features=None, out_features=None, bias=True, use_batch_norm=True, activation=ELU(alpha=1.0))

Bases: Module

forward(x)

Applies the layer to the input x.

class hamgnn.utils.mlp.Dense(in_features, out_features, bias=True, activation=None, weight_init=xavier_uniform_, bias_init=functools.partial(constant_, val=0.0))

Bases: Linear

Fully connected linear layer with an activation function, adapted from SchNetPack:

\[y = \text{activation}(xW^\top + b)\]

Parameters:
  • in_features (int) – Number of input features (dimension of \(x\)).

  • out_features (int) – Number of output features (dimension of \(y\)).

  • bias (bool, optional) – If False, the layer has no bias \(b\).

  • activation (callable, optional) – Activation function; if None, no activation is applied.

  • weight_init (callable, optional) – Initializer applied to the weight matrix.

  • bias_init (callable, optional) – Initializer applied to the bias.

reset_parameters()

Reinitialize model weight and bias values.

forward(inputs)

Computes the layer output.

Parameters:

inputs (torch.Tensor) – Batch of input values.

Returns:

Layer output.

Return type:

torch.Tensor
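Stripped of PyTorch, the forward pass of Dense is the affine map followed by an optional activation. The sketch below evaluates that formula on plain Python lists, with the weight laid out as (out_features, in_features) as in torch.nn.Linear. The function name dense is hypothetical, and the initialization performed by weight_init and bias_init is omitted.

```python
import math
from typing import Callable, List, Optional

Matrix = List[List[float]]


def dense(x: Matrix, weight: Matrix, bias: Optional[List[float]] = None,
          activation: Optional[Callable[[float], float]] = None) -> Matrix:
    """Compute activation(x W^T + b) for a batch x of shape (N, in_features)."""
    out = []
    for row in x:
        # One dot product per output feature: y_j = sum_i x_i * W[j][i].
        y = [sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight]
        if bias is not None:
            y = [yi + bi for yi, bi in zip(y, bias)]
        if activation is not None:
            y = [activation(yi) for yi in y]
        out.append(y)
    return out


print(dense([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]], [0.5, -0.5], activation=math.tanh))
```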

Activation Functions

hamgnn.utils.activation.shifted_softplus(x)
hamgnn.utils.activation.switch_function(x, cuton, cutoff)
hamgnn.utils.activation.softplus_inverse(x)
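The helpers above are listed without docstrings. The sketch below shows the definitions these names usually carry in SchNet- and PhysNet-derived code: shifted_softplus(x) = softplus(x) - ln 2, and softplus_inverse(x) = ln(e^x - 1), the inverse of softplus for x > 0. Treat both as assumptions about HamGNN's implementation; they are written here for scalar floats rather than tensors.

```python
import math


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def shifted_softplus(x: float) -> float:
    # Assumed SchNet-style definition: softplus shifted down by ln 2 so f(0) = 0.
    return softplus(x) - math.log(2.0)


def softplus_inverse(x: float) -> float:
    # Assumed definition: ln(exp(x) - 1) = x + ln(1 - exp(-x)); valid only for x > 0.
    return x + math.log(-math.expm1(-x))


print(shifted_softplus(0.0), softplus_inverse(softplus(1.3)))
```

softplus_inverse is typically used to initialize a raw parameter so that applying softplus to it recovers a desired positive value.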
class hamgnn.utils.activation.SoftUnitStepCutoff(cutoff)

Bases: Module

A PyTorch module that applies a soft unit step function with a cutoff.

cutoff

The distance at which the cutoff is applied.

Type:

float

cut_param

A learnable parameter influencing the softness of the step.

Type:

nn.Parameter

forward(edge_distance)

Forward pass for the module.

Applies the soft unit step function to the input edge distances.

Parameters:

edge_distance (Tensor) – A tensor containing edge distances.

Returns:

A tensor with the calculated edge weights after applying the cutoff.

Return type:

Tensor

hamgnn.utils.activation.swish(x)
class hamgnn.utils.activation.SSP(beta=1, threshold=20)

Bases: Module

Shifted SoftPlus (SSP). Applies \(\text{SSP}(x)=\text{Softplus}(x)-\text{Softplus}(0)\) element-wise, so that \(\text{SSP}(0)=0\).

Parameters:
  • beta – the \(\beta\) value for the Softplus formulation. Default: 1

  • threshold – values above this revert to a linear function. Default: 20

Shape:
  • Input: \((N, *)\), where * means any number of additional dimensions

  • Output: \((N, *)\), same shape as the input
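The effect of beta and threshold can be checked with the scalar sketch below. It reproduces the documented torch.nn.Softplus rule (log(1 + exp(beta x)) scaled by 1/beta, switching to the identity once beta x exceeds threshold) and then subtracts Softplus(0). The function names are hypothetical; the module itself operates element-wise on tensors.

```python
import math


def torch_style_softplus(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    # Same rule as torch.nn.Softplus: (1 / beta) * log(1 + exp(beta * x)),
    # switching to the identity once beta * x exceeds the threshold.
    if beta * x > threshold:
        return x
    return math.log1p(math.exp(beta * x)) / beta


def ssp(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    # SSP(x) = Softplus(x) - Softplus(0), so SSP(0) == 0 for every beta.
    return torch_style_softplus(x, beta, threshold) - torch_style_softplus(0.0, beta, threshold)


print([round(ssp(x), 4) for x in (-1.0, 0.0, 1.0, 25.0)])
```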

forward(input)

Applies SSP element-wise to input.

extra_repr()

Returns the extra representation string that is shown when the module is printed.

class hamgnn.utils.activation.SWISH

Bases: Module

forward(input)

Applies the swish activation element-wise to input.
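Neither swish nor SWISH documents its formula. Assuming the common definition swish(x) = x * sigmoid(x), which PyTorch also provides as torch.nn.SiLU, a numerically stable scalar version looks like this (function names hypothetical):

```python
import math


def sigmoid(z: float) -> float:
    # Branch on the sign so that math.exp never overflows.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def swish(x: float) -> float:
    # Assumed definition: swish(x) = x * sigmoid(x), i.e. SiLU.
    return x * sigmoid(x)


print([round(swish(x), 4) for x in (-2.0, 0.0, 2.0)])
```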

hamgnn.utils.activation.get_activation(name)

Basis Functions

class hamgnn.utils.basis_functions.BernsteinRadialBasisFunctions(num_basis_functions, cutoff)

Bases: Module

reset_parameters()
forward(r)

Expands the distances r in num_basis_functions Bernstein polynomial radial basis functions.
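As a sketch of what such a basis computes, assume distances are rescaled to x = r / cutoff and that num_basis_functions polynomials of degree n = num_basis_functions - 1 are used, b_k(x) = C(n, k) x^k (1 - x)^(n - k). HamGNN may additionally apply a cutoff function or learnable parameters; the function name bernstein_basis is hypothetical. A useful property of this basis is that the functions sum to one for every x in [0, 1].

```python
import math
from typing import List


def bernstein_basis(r: float, num_basis_functions: int, cutoff: float) -> List[float]:
    # Assumption: x = r / cutoff, clipped to [0, 1], and polynomials of degree
    # n = num_basis_functions - 1:  b_k(x) = C(n, k) * x**k * (1 - x)**(n - k).
    n = num_basis_functions - 1
    x = min(max(r / cutoff, 0.0), 1.0)
    return [math.comb(n, k) * x**k * (1.0 - x) ** (n - k) for k in range(n + 1)]


values = bernstein_basis(1.7, num_basis_functions=8, cutoff=5.0)
print(len(values), sum(values))
```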

class hamgnn.utils.basis_functions.ExponentialBernsteinRadialBasisFunctions(num_basis_functions, cutoff, ini_alpha=0.5)

Bases: Module

reset_parameters()
forward(r)

Expands the distances r in num_basis_functions exponential Bernstein radial basis functions.
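One common construction of an exponential Bernstein basis, used for example in SpookyNet, evaluates the Bernstein polynomials at x = exp(-alpha r) instead of r / cutoff and multiplies them by a smooth cutoff. The sketch below follows that construction under two assumptions not confirmed by this reference: that ini_alpha is the initial value of alpha, and that the cutoff has the SpookyNet form exp(-r^2 / ((c - r)(c + r))). The function names are hypothetical.

```python
import math
from typing import List


def smooth_cutoff(r: float, cutoff: float) -> float:
    # Assumed SpookyNet-style cutoff: exp(-r^2 / ((c - r)(c + r))) for r < c, else 0.
    if r >= cutoff:
        return 0.0
    return math.exp(-r * r / ((cutoff - r) * (cutoff + r)))


def exp_bernstein_basis(r: float, num_basis_functions: int, cutoff: float,
                        alpha: float = 0.5) -> List[float]:
    # Bernstein polynomials evaluated at x = exp(-alpha * r) rather than r / cutoff.
    # Assumption: alpha plays the role of ini_alpha (learnable in a real model).
    n = num_basis_functions - 1
    x = math.exp(-alpha * r)
    fc = smooth_cutoff(r, cutoff)
    return [math.comb(n, k) * x**k * (1.0 - x) ** (n - k) * fc for k in range(n + 1)]


print([round(v, 4) for v in exp_bernstein_basis(1.0, num_basis_functions=6, cutoff=5.0)])
```

Because exp(-alpha r) compresses large distances, more of the basis resolution is spent at short range, where interactions change fastest.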

class hamgnn.utils.basis_functions.ExponentialGaussianRadialBasisFunctions(num_basis_functions, cutoff, ini_alpha=0.5)

Bases: Module

reset_parameters()
forward(r)

Expands the distances r in num_basis_functions exponential Gaussian radial basis functions.

class hamgnn.utils.basis_functions.GaussianRadialBasisFunctions(num_basis_functions, cutoff)

Bases: Module

reset_parameters()
forward(r)

Expands the distances r in num_basis_functions Gaussian radial basis functions.

class hamgnn.utils.basis_functions.OverlapBernsteinRadialBasisFunctions(num_basis_functions, cutoff, ini_alpha=0.5)

Bases: Module

reset_parameters()
forward(r)

Expands the distances r in num_basis_functions overlap Bernstein radial basis functions.

class hamgnn.utils.basis_functions.sph_harm_layer(num_spherical)

Bases: Module

forward(angle)

Expands the input angles in num_spherical spherical-harmonic basis functions.

class hamgnn.utils.basis_functions.BesselBasis(cutoff=5.0, n_rbf=None, cutoff_func=None)

Bases: Module

Sine-based radial basis expansion with Coulomb-like 1/r decay (the 0th-order spherical Bessel basis from DimeNet).

forward(dist)

Computes the 0th-order Bessel expansion of interatomic distances.

Parameters:

dist (torch.Tensor) – Interatomic distances with shape (N_edge,).

Returns:

The 0th-order Bessel expansion of the interatomic distances, with shape (N_edge, n_rbf).

Return type:

torch.Tensor
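The DimeNet basis referred to above is e_n(d) = sqrt(2 / c) sin(n pi d / c) / d for n = 1, ..., n_rbf, where c is the cutoff; the sin(...) / d factor supplies the 1/d decay. The scalar sketch below evaluates it, assuming that cutoff_func, when given, simply multiplies every basis function; the function name bessel_basis is hypothetical. The expression has a finite limit as d approaches 0, but this direct evaluation divides by d, so it expects d > 0.

```python
import math
from typing import Callable, List, Optional


def bessel_basis(dist: float, n_rbf: int, cutoff: float = 5.0,
                 cutoff_func: Optional[Callable[[float], float]] = None) -> List[float]:
    # e_n(d) = sqrt(2 / c) * sin(n * pi * d / c) / d for n = 1, ..., n_rbf.
    # Dividing sin(...) by d gives the 1/d decay; requires d > 0.
    prefactor = math.sqrt(2.0 / cutoff)
    values = [prefactor * math.sin(n * math.pi * dist / cutoff) / dist
              for n in range(1, n_rbf + 1)]
    if cutoff_func is not None:
        # Assumption: the optional envelope multiplies every basis function.
        envelope = cutoff_func(dist)
        values = [v * envelope for v in values]
    return values


print(bessel_basis(1.2, n_rbf=4))
```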

class hamgnn.utils.basis_functions.GaussianSmearing(start=0.0, stop=5.0, num_gaussians=50, cutoff_func=None)

Bases: Module

forward(dist)

Expands the distances dist in num_gaussians Gaussian functions centred between start and stop.
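The signature matches the Gaussian smearing familiar from SchNet-style models (and from PyTorch Geometric): num_gaussians centres evenly spaced between start and stop, each with a width equal to the grid spacing. The scalar sketch below assumes HamGNN follows that convention and that cutoff_func, if given, multiplies the result; the function name gaussian_smearing is hypothetical.

```python
import math
from typing import Callable, List, Optional


def gaussian_smearing(dist: float, start: float = 0.0, stop: float = 5.0,
                      num_gaussians: int = 50,
                      cutoff_func: Optional[Callable[[float], float]] = None) -> List[float]:
    # Centres mu_k evenly spaced on [start, stop]; the Gaussian width equals the spacing.
    step = (stop - start) / (num_gaussians - 1)
    offsets = [start + k * step for k in range(num_gaussians)]
    coeff = -0.5 / step**2
    values = [math.exp(coeff * (dist - mu) ** 2) for mu in offsets]
    if cutoff_func is not None:
        # Assumption: the optional envelope multiplies every Gaussian.
        envelope = cutoff_func(dist)
        values = [v * envelope for v in values]
    return values


print(gaussian_smearing(1.2, num_gaussians=10))
```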

Cutoff Functions

hamgnn.utils.cutoff_functions.cutoff_function(x, cutoff)
class hamgnn.utils.cutoff_functions.cuttoff_envelope(cutoff, exponent=6)

Bases: Module

forward(x)

Evaluates the cutoff envelope at the distances x.
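Neither cutoff_function nor cuttoff_envelope documents its formula. A widely used envelope with an exponent parameter (default 6 in NequIP) is the polynomial cutoff f(x) = 1 - (p+1)(p+2)/2 x^p + p(p+2) x^(p+1) - p(p+1)/2 x^(p+2) with x = r / cutoff, which falls smoothly from 1 at r = 0 to 0 at r = cutoff with a vanishing slope there. The sketch below assumes that form; the function name polynomial_envelope is hypothetical, and this is an illustration rather than HamGNN's confirmed implementation.

```python
def polynomial_envelope(r: float, cutoff: float, p: int = 6) -> float:
    # Assumed NequIP-style polynomial cutoff of order p, with x = r / cutoff:
    #   f(x) = 1 - (p+1)(p+2)/2 x^p + p(p+2) x^(p+1) - p(p+1)/2 x^(p+2)
    # f(0) = 1 and f(1) = 0; distances at or beyond the cutoff are mapped to 0.
    x = r / cutoff
    if x >= 1.0:
        return 0.0
    return (1.0
            - (p + 1) * (p + 2) / 2 * x**p
            + p * (p + 2) * x ** (p + 1)
            - p * (p + 1) / 2 * x ** (p + 2))


print([round(polynomial_envelope(r, cutoff=5.0), 4) for r in (0.0, 2.5, 4.5, 5.0)])
```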

class hamgnn.utils.cutoff_functions.CosineCutoff(cutoff=5.0)

Bases: Module

Behler cosine cutoff, adapted from SchNetPack:

\[
f(r) = \begin{cases}
  0.5 \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] & r < r_\text{cutoff} \\
  0 & r \geqslant r_\text{cutoff}
\end{cases}
\]

Parameters:

cutoff (float, optional) – Cutoff radius.

forward(distances)

Computes the cutoff function.

Parameters:

distances (torch.Tensor) – Values of interatomic distances.

Returns:

Values of the cutoff function.

Return type:

torch.Tensor
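The piecewise formula above translates directly into a scalar Python function; the name cosine_cutoff is hypothetical, and the module applies the same expression element-wise to a tensor of distances.

```python
import math


def cosine_cutoff(r: float, cutoff: float = 5.0) -> float:
    # f(r) = 0.5 * [1 + cos(pi * r / r_cutoff)] for r < r_cutoff, else 0.
    if r >= cutoff:
        return 0.0
    return 0.5 * (1.0 + math.cos(math.pi * r / cutoff))


print([cosine_cutoff(r) for r in (0.0, 2.5, 5.0, 6.0)])
```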

class hamgnn.utils.cutoff_functions.SoftUnitStepCutoff(cutoff)

Bases: Module

A PyTorch module that applies a soft unit step function with a cutoff.

cutoff

The distance at which the cutoff is applied.

Type:

float

cut_param

A learnable parameter influencing the softness of the step.

Type:

nn.Parameter

forward(edge_distance)

Forward pass for the module.

Applies the soft unit step function to the input edge distances.

Parameters:

edge_distance (Tensor) – A tensor containing edge distances.

Returns:

A tensor with the calculated edge weights after applying the cutoff.

Return type:

Tensor
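The functional form of the soft step is not documented here. One plausible construction, assumed rather than confirmed, passes cut_param * (1 - r / cutoff) through the smooth step defined in e3nn (e3nn.math.soft_unit_step, which computes exp(-1/x) for x > 0 and 0 otherwise). Under that assumption the weights vanish smoothly at the cutoff and a larger cut_param makes the drop sharper. The function names and the cut_param value in the usage line are hypothetical.

```python
import math


def soft_unit_step(x: float) -> float:
    # Smooth step with the e3nn.math.soft_unit_step definition: exp(-1/x) if x > 0, else 0.
    return math.exp(-1.0 / x) if x > 0.0 else 0.0


def soft_unit_step_cutoff(edge_distance: float, cutoff: float, cut_param: float) -> float:
    # Assumed form: weights reach 0 at r = cutoff; a larger cut_param sharpens the drop.
    return soft_unit_step(cut_param * (1.0 - edge_distance / cutoff))


print([round(soft_unit_step_cutoff(r, cutoff=5.0, cut_param=10.0), 4) for r in (1.0, 4.0, 5.0)])
```

Making cut_param learnable, as the attribute description suggests, lets training decide how abruptly distant neighbours are switched off.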

Regression Layers

class hamgnn.utils.regression_layers.denseRegression(in_features=None, out_features=None, bias=True, use_batch_norm=True, activation=Softplus(beta=1, threshold=20), n_h=3)

Bases: Module

forward(x)

Applies the regression network to the input x.

class hamgnn.utils.regression_layers.MLPRegression(num_in_features, num_out_features, num_mlp=3, lbias=False, activation=ELU(alpha=1.0), use_batch_norm=False)

Bases: Module

forward(x)

Applies the regression network to the input x.

Hyperparameter Configuration

hamgnn.utils.hparam.get_hparam_dict(config=None)