Interaction Blocks

Neural Network Interaction Components

Pair and correlation interaction blocks for updating node and edge irreps.

Provides PairInteractionBlock, CorrProductBlock, residual wrappers, and MACE-compatible equivariant product pathways.
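Every block below consumes and produces flat feature tensors whose last dimension is fixed by an irreps specification. The pure-Python sketch below shows how that dimension follows from the usual e3nn "MULxLp" string format: each irrep of degree l contributes 2l + 1 components. It has no e3nn dependency. parse_irreps and irreps_dim are hypothetical helpers, not part of hamgnn. In e3nn itself the same number is available as o3.Irreps(...).dim.

```python
import re
from typing import List, Tuple


def parse_irreps(spec: str) -> List[Tuple[int, int, str]]:
    """Parse an e3nn-style irreps string such as '32x0e + 16x1o' into
    (multiplicity, degree l, parity) tuples. A bare '1o' means multiplicity 1."""
    irreps = []
    for term in spec.split("+"):
        term = term.strip()
        match = re.fullmatch(r"(?:(\d+)x)?(\d+)([eo])", term)
        if match is None:
            raise ValueError(f"cannot parse irrep term: {term!r}")
        mul = int(match.group(1)) if match.group(1) else 1
        irreps.append((mul, int(match.group(2)), match.group(3)))
    return irreps


def irreps_dim(spec: str) -> int:
    """Total feature dimension: each irrep of degree l has 2l + 1 components."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(spec))


print(irreps_dim("32x0e + 16x1o + 8x2e"))  # 32*1 + 16*3 + 8*5 = 120
```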

class hamgnn.nn.interaction_blocks.PairInteractionBlock(irreps_node_feats, irreps_node_attrs, irreps_edge_attrs, irreps_edge_embed, irreps_edge_feats, use_skip_connections=False, legacy_edge_update=False, use_kan=False, radial_MLP=None, nonlinearity_type='gate', nonlinearity_scalars={'e': 'ssp', 'o': 'tanh'}, nonlinearity_gates={'e': 'ssp', 'o': 'abs'}, lite_mode=False)

Bases: Module

A pair interaction block for updating edge features based on node features and edge attributes.

Parameters:
  • irreps_node_feats (o3.Irreps): Irreducible representations for node features.

  • irreps_node_attrs (o3.Irreps): Irreducible representations for node attributes.

  • irreps_edge_attrs (o3.Irreps): Irreducible representations for edge attributes.

  • irreps_edge_embed (o3.Irreps): Irreducible representations for edge embeddings.

  • irreps_edge_feats (o3.Irreps): Irreducible representations for edge features.

  • use_skip_connections (bool): Whether to use skip connections.

  • legacy_edge_update (bool): If True and use_skip_connections is False, keep the legacy (buggy) behavior in which edge features are not updated by the conv_tp output. Use only to reproduce results from old checkpoints.

  • use_kan (bool): Whether to use a Kolmogorov-Arnold network (KAN) instead of an MLP for the radial network.

  • radial_MLP (Optional[List[int]]): Hidden-layer widths of the radial MLP. If None, defaults to [64, 64, 64].

  • nonlinearity_type (str): Type of nonlinearity to use (“gate” or “norm”).

  • nonlinearity_scalars (Dict[str, str]): Activation for scalar channels, keyed by parity ('e' for even, 'o' for odd).

  • nonlinearity_gates (Dict[str, str]): Activation for gate channels, keyed by parity ('e' for even, 'o' for odd).

  • lite_mode (bool): If True, use the lightest configuration, with the fewest parameters and the fastest runtime.
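With nonlinearity_type='gate', scalar channels pass through an activation directly, while each non-scalar channel is multiplied by an activated gate scalar. Because the gate is rotation-invariant, the result stays equivariant. The NumPy sketch below covers only even scalars and l=1 vectors and assumes that 'ssp' in the defaults denotes the shifted softplus, softplus(x) - log 2. The helper names (ssp, gate, rotation_z) are hypothetical; this illustrates the technique, not hamgnn's code path.

```python
import numpy as np


def ssp(x: np.ndarray) -> np.ndarray:
    """Shifted softplus, softplus(x) - log(2), so that ssp(0) == 0."""
    return np.logaddexp(0.0, x) - np.log(2.0)


def gate(scalars: np.ndarray, gates: np.ndarray, vectors: np.ndarray):
    """Gate nonlinearity for even scalars and l=1 vectors.

    scalars: (N, n_s) even scalars, passed through ssp
    gates:   (N, n_v) even gate scalars, one per vector channel
    vectors: (N, n_v, 3) l=1 features, each scaled by ssp(gate)
    """
    return ssp(scalars), ssp(gates)[..., None] * vectors


def rotation_z(theta: float) -> np.ndarray:
    """3x3 rotation matrix about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


rng = np.random.default_rng(0)
s = rng.normal(size=(4, 8))
g = rng.normal(size=(4, 5))
v = rng.normal(size=(4, 5, 3))
R = rotation_z(0.7)

# Gating then rotating must equal rotating then gating.
_, out_then_rot = gate(s, g, v)
out_then_rot = out_then_rot @ R.T
_, rot_then_out = gate(s, g, v @ R.T)
print(np.allclose(out_then_rot, rot_then_out))  # True
```

Applying an activation such as ssp directly to vector components would break this equivariance check, which is why the gate only ever touches invariant scalars.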

create_linear(irreps_in, irreps_out=None)

Create a linear layer.

Parameters:
  • irreps_in (o3.Irreps): Input irreps for the linear layer.

  • irreps_out (o3.Irreps, optional): Output irreps for the linear layer.

Returns: o3.Linear: A linear transformation layer.
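An e3nn o3.Linear layer only connects irreps of the same type (same l and parity). It mixes multiplicity channels with weights that are shared across the 2l + 1 components, and this sharing is what makes it commute with rotations. The NumPy sketch below illustrates this for a single l=1 block and checks equivariance numerically. It is a simplified stand-in that assumes no bias, a single irrep type, and 1/sqrt(fan-in) scaling. equivariant_linear is a hypothetical name and is not the layer that create_linear returns.

```python
import numpy as np


def equivariant_linear(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Channel-mixing linear map for one irrep type.

    x:      (N, mul_in, 2l+1) features of a single degree l and parity
    weight: (mul_in, mul_out) weights shared across the 2l+1 components
    Returns (N, mul_out, 2l+1). Only channels are mixed, never m-components.
    """
    mul_in = x.shape[1]
    return np.einsum("nim,io->nom", x, weight) / np.sqrt(mul_in)


rng = np.random.default_rng(1)
x = rng.normal(size=(6, 16, 3))  # 16x1o features on 6 nodes
w = rng.normal(size=(16, 8))     # maps 16x1o -> 8x1o
c, s = np.cos(1.1), np.sin(1.1)
R = np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])  # rotation about x

lhs = equivariant_linear(x @ R.T, w)  # rotate, then apply the linear map
rhs = equivariant_linear(x, w) @ R.T  # apply the linear map, then rotate
print(lhs.shape, np.allclose(lhs, rhs))  # (6, 8, 3) True
```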

forward(data)

Forward pass of the pair interaction block.

Parameters:

data (Dict[str, torch.Tensor]) – A dictionary containing the graph data.

Returns:

Updated edge features.

Return type:

Tensor

class hamgnn.nn.interaction_blocks.CorrProductBlock(irreps_node_feats, num_hidden_features, correlation, use_skip_connections=True, num_elements=None)

Bases: Module

A correlation product block for updating node features using an equivariant product operation.

Parameters:
  • irreps_node_feats (o3.Irreps): Irreducible representations for node features.

  • num_hidden_features (int): Number of hidden features.

  • correlation (int): Correlation order of the equivariant product.

  • use_skip_connections (bool): Whether to use skip connections.

  • num_elements (int, optional): Number of element types used by the product operation.
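In a MACE-style equivariant product, the correlation order sets how many copies of the node features are multiplied together, and the weights depend on each atom's element. The sketch below restricts that idea to scalar (l=0) channels, where the Clebsch-Gordan couplings reduce to ordinary powers. It assumes that num_elements indexes element-dependent weights as in MACE. scalar_corr_product is a hypothetical name, and channels with l > 0 are not covered.

```python
import numpy as np


def scalar_corr_product(a: np.ndarray, species: np.ndarray,
                        weights: np.ndarray) -> np.ndarray:
    """Element-dependent polynomial product for l=0 channels only.

    a:       (N, C) scalar node features
    species: (N,) integer element index in [0, num_elements)
    weights: (num_elements, correlation, C) one weight per element, power, channel
    Returns (N, C) with out[n, c] = sum_nu weights[z_n, nu-1, c] * a[n, c] ** nu.
    """
    correlation = weights.shape[1]
    powers = np.stack([a ** nu for nu in range(1, correlation + 1)], axis=1)  # (N, corr, C)
    w = weights[species]  # gather each node's element weights: (N, corr, C)
    return (w * powers).sum(axis=1)


a = np.array([[2.0, -1.0], [0.5, 3.0]])  # 2 nodes, 2 channels
species = np.array([0, 1])               # the two nodes are different elements
weights = np.ones((2, 3, 2))             # num_elements=2, correlation=3
weights[1] *= 0.5                        # element 1 gets smaller weights
out = scalar_corr_product(a, species, weights)
# node 0 -> [2+4+8, -1+1-1] = [14, -1]; node 1 -> 0.5 * [0.875, 39] = [0.4375, 19.5]
```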

forward(data)

Forward pass of the correlation product block.

Parameters:

data (Dict[str, torch.Tensor]) – A dictionary containing the graph data.

Returns:

Updated node features.

Return type:

Tensor

class hamgnn.nn.interaction_blocks.ResidualBlock(irreps_in, feature_irreps_hidden, resnet=True, nonlinearity_type='gate', nonlinearity_scalars={'e': 'ssp', 'o': 'tanh'}, nonlinearity_gates={'e': 'ssp', 'o': 'abs'})

Bases: Module

A residual block used in equivariant neural networks.

Parameters:
  • irreps_in (str) – The input irreducible representations (irreps).

  • feature_irreps_hidden (str) – The hidden feature irreps.

  • resnet (bool) – If True, apply a residual connection.

  • nonlinearity_type (str) – The type of nonlinearity to apply (‘gate’ or ‘norm’).

  • nonlinearity_scalars (Dict[str, str]) – A dictionary mapping parity ('e' for even, 'o' for odd) to the activation for scalar features.

  • nonlinearity_gates (Dict[str, str]) – A dictionary mapping parity ('e' for even, 'o' for odd) to the activation for gate features.
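The resnet flag follows the standard residual pattern: the block's transformation is added back onto its input. This addition is only well-defined when the output has the same layout as the input, which is consistent with the output matching irreps_in. The minimal NumPy sketch below assumes the skip connection is a plain addition. The real block may route through feature_irreps_hidden differently. residual_block and halve are hypothetical stand-ins.

```python
from typing import Callable

import numpy as np


def residual_block(x: np.ndarray, f: Callable[[np.ndarray], np.ndarray],
                   resnet: bool = True) -> np.ndarray:
    """Return x + f(x) when resnet is True, otherwise f(x)."""
    y = f(x)
    if not resnet:
        return y
    if y.shape != x.shape:
        # The skip connection requires f to preserve the feature layout.
        raise ValueError(f"residual needs matching shapes, got {x.shape} and {y.shape}")
    return x + y


def halve(h: np.ndarray) -> np.ndarray:
    """Trivially equivariant stand-in for the block's inner layers."""
    return 0.5 * h


rng = np.random.default_rng(2)
x = rng.normal(size=(5, 120))  # 5 nodes, e.g. 32x0e + 16x1o + 8x2e flattened
out = residual_block(x, halve)                # equals 1.5 * x
plain = residual_block(x, halve, resnet=False)  # equals 0.5 * x
```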

create_nonlinearity(nonlinearity_type, irreps_mid, nonlinearity_scalars, nonlinearity_gates)

Create the nonlinearity module (‘gate’ or ‘norm’) applied to features with irreps irreps_mid.
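The 'norm' option differs from the gate by acting on each irrep's length. A common construction of this kind, for example e3nn's NormActivation, keeps the irrep's direction and replaces its norm with an activation of that norm. Because norms are rotation-invariant, the result is equivariant. The NumPy sketch below assumes that construction for a single irrep block. norm_activation is a hypothetical name.

```python
import numpy as np


def norm_activation(x: np.ndarray, act, eps: float = 1e-8) -> np.ndarray:
    """Norm nonlinearity for features shaped (N, mul, 2l+1).

    Each irrep copy keeps its direction while its length becomes act(|x|),
    so only rotation-invariant norms pass through the activation.
    """
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x * act(norms) / np.maximum(norms, eps)


rng = np.random.default_rng(3)
x = rng.normal(size=(4, 6, 5))  # 6x2e features on 4 nodes (2l+1 = 5)
out = norm_activation(x, np.tanh)

# Any orthogonal matrix preserves norms. Real Wigner D-matrices are orthogonal,
# so a random orthogonal Q is a sufficient equivariance check here.
Q, _ = np.linalg.qr(rng.normal(size=(5, 5)))
print(np.allclose(norm_activation(x @ Q.T, np.tanh), out @ Q.T))  # True
```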

forward(x)

Forward pass of the residual block.

Parameters:

x (torch.Tensor) – Input tensor with shape matching irreps_in.

Returns:

Output tensor with shape matching irreps_in.

Return type:

torch.Tensor