Source code for hamgnn.nn.interaction_blocks

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""Pair and correlation interaction blocks for updating node and edge irreps.

Provides :class:`PairInteractionBlock`, :class:`CorrProductBlock`, residual wrappers,
and MACE-compatible equivariant product pathways.
"""

from typing import Callable, Dict, List, Optional, Tuple

import torch
from e3nn import o3
from e3nn.nn import Gate, NormActivation
from e3nn.util.jit import compile_mode
from torch import nn

from .message_passing import MessagePackBlock
from ..toolbox.nequip.data import AtomicDataDict
from ..toolbox.mace.modules.blocks import EquivariantProductBasisBlock
from ..toolbox.mace.modules.irreps_tools import (
    linear_out_irreps,
    reshape_irreps,
    tp_out_irreps_with_instructions,
)
from ..utils.irreps_utils import acts, irreps2gate


@compile_mode("script")
class PairInteractionBlock(nn.Module):
    """
    A pair interaction block for updating edge features based on node features
    and edge attributes.

    The node features are passed through two separate linear layers (one for the
    source end of each edge, one for the target end), gathered per edge, and
    handed to a :class:`MessagePackBlock` together with the current edge
    features, the edge attributes and the edge embedding.

    Parameters:
    - irreps_node_feats (o3.Irreps): Irreducible representations for node features.
    - irreps_node_attrs (o3.Irreps): Irreducible representations for node attributes.
      Stored on the module; not otherwise used by this block.
    - irreps_edge_attrs (o3.Irreps): Irreducible representations for edge attributes.
    - irreps_edge_embed (o3.Irreps): Irreducible representations for edge embeddings.
    - irreps_edge_feats (o3.Irreps): Irreducible representations for edge features.
    - use_skip_connections (bool): Whether to use skip connections.
    - legacy_edge_update (bool): If True and use_skip_connections is False, keep legacy
      (buggy) behavior where edge features are not updated by conv_tp output.
      Use only for reproducing old checkpoints.
    - use_kan (bool): Whether to use KAN for radial MLP.
    - radial_MLP (Optional[List[int]]): Architecture of the radial MLP.
      Defaults to [64, 64, 64].
    - nonlinearity_type (str): Type of nonlinearity ("gate" or "norm"). Validated only;
      this block does not build a nonlinearity.
    - nonlinearity_scalars (Optional[Dict[str, str]]): Mapping with keys "e" and "o".
      Defaults to {"e": "ssp", "o": "tanh"}. Validated only.
    - nonlinearity_gates (Optional[Dict[str, str]]): Mapping with keys "e" and "o".
      Defaults to {"e": "ssp", "o": "abs"}. Validated only.
    - lite_mode (bool): The mode with the fewest model parameters and the fastest
      running speed.
    """

    def __init__(
        self,
        irreps_node_feats: o3.Irreps,
        irreps_node_attrs: o3.Irreps,
        irreps_edge_attrs: o3.Irreps,
        irreps_edge_embed: o3.Irreps,
        irreps_edge_feats: o3.Irreps,
        use_skip_connections: bool = False,
        legacy_edge_update: bool = False,
        use_kan: bool = False,
        radial_MLP: Optional[List[int]] = None,
        nonlinearity_type: str = "gate",
        nonlinearity_scalars: Optional[Dict[str, str]] = None,
        nonlinearity_gates: Optional[Dict[str, str]] = None,
        lite_mode: bool = False,
    ) -> None:
        super().__init__()

        # Defaults are created per call instead of sharing one dict object
        # between every instance through the signature.
        if nonlinearity_scalars is None:
            nonlinearity_scalars = {"e": "ssp", "o": "tanh"}
        if nonlinearity_gates is None:
            nonlinearity_gates = {"e": "ssp", "o": "abs"}

        self.radial_MLP = radial_MLP or [64, 64, 64]
        self.use_skip_connections = use_skip_connections
        self.legacy_edge_update = legacy_edge_update
        self.use_kan = use_kan
        self.lite_mode = lite_mode

        # Assign irreps (normalised to o3.Irreps so strings are accepted too)
        self.irreps_node_feats = o3.Irreps(irreps_node_feats)
        self.irreps_edge_attrs = o3.Irreps(irreps_edge_attrs)
        self.irreps_edge_embed = o3.Irreps(irreps_edge_embed)
        self.irreps_edge_feats = o3.Irreps(irreps_edge_feats)
        self.irreps_node_attrs = o3.Irreps(irreps_node_attrs)

        assert nonlinearity_type in (
            "gate", "norm"), "Invalid nonlinearity type."

        # The parity mappings are not consumed by this block. The lookups are
        # kept so that a mapping without "e"/"o" still fails early with KeyError.
        for parity_map in (nonlinearity_scalars, nonlinearity_gates):
            _ = (parity_map["e"], parity_map["o"])

        # Linear transformations applied to the node features before they are
        # gathered at the source / target end of every edge
        self.linear_up_src = self.create_linear(self.irreps_node_feats)
        self.linear_up_tar = self.create_linear(self.irreps_node_feats)

        # TensorProduct layer for edge feature mixing
        self.conv_tp = MessagePackBlock(
            irreps_node_feats=self.irreps_node_feats,
            irreps_edge_feats=self.irreps_edge_feats,
            irreps_local_env_edge=self.irreps_edge_attrs,
            irreps_out=self.irreps_edge_feats,
            irreps_edge_scalars=self.irreps_edge_embed,
            radial_MLP=self.radial_MLP,
            use_kan=self.use_kan,
            lite_mode=self.lite_mode,
        )

        # Skip connection (only created when requested). Built from the
        # normalised irreps, like every other layer of this block.
        if self.use_skip_connections:
            self.skip_linear = self.create_linear(
                self.irreps_edge_feats, self.irreps_edge_feats)

    def create_linear(self, irreps_in, irreps_out=None):
        """
        Create a linear layer.

        Parameters:
        - irreps_in (o3.Irreps): Input irreps for the linear layer.
        - irreps_out (o3.Irreps, optional): Output irreps for the linear layer.
          When None, the layer maps irreps_in onto itself.

        Returns:
        - o3.Linear: A linear transformation layer with internal, shared weights.
        """
        # Explicit None check: an empty irreps is falsy and must not be
        # mistaken for "not given".
        if irreps_out is None:
            irreps_out = irreps_in
        return o3.Linear(
            irreps_in,
            irreps_out,
            internal_weights=True,
            shared_weights=True,
        )

    def forward(self, data: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Forward pass of the pair interaction block.

        Reads the edge index, node features, edge embedding, edge features and
        edge attributes from ``data`` and writes the resulting edge features back
        under ``AtomicDataDict.EDGE_FEATURES_KEY``.

        Parameters:
        - data (Dict[str, torch.Tensor]): A dictionary containing the graph data.

        Returns:
        - torch.Tensor: Updated edge features (the tensor stored in ``data``).
        """
        edge_src, edge_dst = data[AtomicDataDict.EDGE_INDEX_KEY]
        node_feats = data[AtomicDataDict.NODE_FEATURES_KEY]
        edge_embed = data[AtomicDataDict.EDGE_EMBEDDING_KEY]
        edge_feats = data[AtomicDataDict.EDGE_FEATURES_KEY]

        # Mixing node features for edge features
        edge_feats_mix = self.conv_tp(
            self.linear_up_src(node_feats)[edge_src],
            self.linear_up_tar(node_feats)[edge_dst],
            edge_feats,
            data[AtomicDataDict.EDGE_ATTRS_KEY],
            edge_embed,
        )

        if self.use_skip_connections:
            edge_feats = edge_feats_mix + self.skip_linear(edge_feats)
        elif self.legacy_edge_update:
            # Legacy: when no skip, keep original edge_feats
            # (reproduce old buggy behavior)
            pass
        else:
            edge_feats = edge_feats_mix

        data[AtomicDataDict.EDGE_FEATURES_KEY] = edge_feats
        return edge_feats
@compile_mode("script")
class CorrProductBlock(nn.Module):
    """
    A correlation product block for updating node features using an equivariant
    product operation.

    The node features are lifted to ``num_hidden_features`` channels per irrep,
    passed through an :class:`EquivariantProductBasisBlock`, projected back to
    the node-feature irreps and, optionally, summed with a linear skip term.

    Parameters:
    - irreps_node_feats (o3.Irreps): Irreducible representations for node features.
    - num_hidden_features (int): Number of hidden features (multiplicity used for
      every irrep of the hidden representation).
    - correlation (int): Correlation level for the product operation.
    - use_skip_connections (bool): Whether to use skip connections.
    - num_elements (int): Number of elements for the product operation.
    """

    def __init__(
        self,
        irreps_node_feats: o3.Irreps,
        num_hidden_features: int,
        correlation: int,
        use_skip_connections: bool = True,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.irreps_node_feats = o3.Irreps(irreps_node_feats).simplify()
        self.num_hidden_features = num_hidden_features
        self.correlation = correlation
        self.use_skip_connections = use_skip_connections
        self.num_elements = num_elements

        # Same irrep types as the node features, each with the hidden multiplicity
        self.irreps_hidden_features = o3.Irreps(
            [(self.num_hidden_features, mul_ir.ir)
             for mul_ir in self.irreps_node_feats]
        )

        # Linear layers for lifting and skip connection.
        # linear_sc is created regardless of use_skip_connections, so the set
        # of parameters does not depend on that flag.
        self.linear_pre = o3.Linear(
            self.irreps_node_feats,
            self.irreps_hidden_features,
            internal_weights=True,
            shared_weights=True,
        )
        self.linear_sc = o3.Linear(
            self.irreps_node_feats,
            self.irreps_node_feats,
            internal_weights=True,
            shared_weights=True,
        )

        # Equivariant product operation
        self.prod = EquivariantProductBasisBlock(
            node_feats_irreps=self.irreps_hidden_features,
            target_irreps=self.irreps_hidden_features,
            correlation=correlation,
            num_elements=num_elements,
            use_sc=False,
        )

        # Linear layer for output
        self.linear_out = o3.Linear(
            self.irreps_hidden_features,
            self.irreps_node_feats,
            internal_weights=True,
            shared_weights=True,
        )
        self.reshape = reshape_irreps(self.irreps_hidden_features)

    def forward(
        self,
        data: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Forward pass of the correlation product block.

        Reads the node features and node attributes from ``data`` and writes the
        updated node features back under ``AtomicDataDict.NODE_FEATURES_KEY``.

        Parameters:
        - data (Dict[str, torch.Tensor]): A dictionary containing the graph data.

        Returns:
        - torch.Tensor: Updated node features, i.e. the same tensor that is
          stored in ``data`` (including the skip term when skip connections
          are enabled).
        """
        node_feats_in = data[AtomicDataDict.NODE_FEATURES_KEY]

        node_feats = self.linear_pre(node_feats_in)
        # [n_nodes, channels, (l + 1)**2]
        node_feats = self.reshape(node_feats)

        out = self.prod(node_feats, None, data[AtomicDataDict.NODE_ATTRS_KEY])
        out = self.linear_out(out)

        if self.use_skip_connections:
            # Fold the skip term into `out` so that the returned tensor and the
            # one stored in `data` are the same (previously the pre-skip tensor
            # was returned while the post-skip tensor was stored).
            out = out + self.linear_sc(node_feats_in)

        data[AtomicDataDict.NODE_FEATURES_KEY] = out
        return out
@compile_mode("script")
class ResidualBlock(nn.Module):
    """
    A residual block used in equivariant neural networks.

    Applies ``linear1 -> equivariant nonlinearity -> linear2`` and, when
    ``resnet`` is True, adds the block input to the result.

    Args:
        irreps_in (str): The input irreducible representations (irreps).
        feature_irreps_hidden (str): The hidden feature irreps.
        resnet (bool): If True, apply a residual connection.
        nonlinearity_type (str): The type of nonlinearity to apply ('gate' or 'norm').
        nonlinearity_scalars (Optional[Dict[str, str]]): Mapping from parity key
            ("e" for even, "o" for odd) to the nonlinearity used for scalar
            features. Defaults to {"e": "ssp", "o": "tanh"}.
        nonlinearity_gates (Optional[Dict[str, str]]): Mapping from parity key
            ("e" for even, "o" for odd) to the nonlinearity used for gated
            features. Defaults to {"e": "ssp", "o": "abs"}.
    """

    def __init__(
        self,
        irreps_in: str,
        feature_irreps_hidden: str,
        resnet: bool = True,
        nonlinearity_type: str = "gate",
        nonlinearity_scalars: Optional[Dict[str, str]] = None,
        nonlinearity_gates: Optional[Dict[str, str]] = None,
    ):
        super().__init__()

        # Defaults are created per call instead of sharing one dict object
        # between every instance through the signature.
        if nonlinearity_scalars is None:
            nonlinearity_scalars = {"e": "ssp", "o": "tanh"}
        if nonlinearity_gates is None:
            nonlinearity_gates = {"e": "ssp", "o": "abs"}

        # Ensure valid nonlinearity type
        assert nonlinearity_type in (
            "gate", "norm"), "Invalid nonlinearity_type. Choose either 'gate' or 'norm'."

        # Re-key the nonlinearities by parity value: "e" -> 1, "o" -> -1
        scalars_by_parity = {
            1: nonlinearity_scalars["e"],
            -1: nonlinearity_scalars["o"],
        }
        gates_by_parity = {
            1: nonlinearity_gates["e"],
            -1: nonlinearity_gates["o"],
        }

        self.irreps_in = o3.Irreps(irreps_in)
        self.feature_irreps_hidden = o3.Irreps(feature_irreps_hidden)
        self.resnet = resnet

        self.equivariant_nonlin = self.create_nonlinearity(
            nonlinearity_type,
            self.feature_irreps_hidden,
            scalars_by_parity,
            gates_by_parity,
        )

        # Define linear layers. Their inner irreps are taken from the
        # nonlinearity module itself so they match what it expects and produces.
        self.linear1 = o3.Linear(
            irreps_in=self.irreps_in,
            irreps_out=self.equivariant_nonlin.irreps_in,
        )
        # Project back onto the normalised input irreps (same object linear1
        # uses) rather than the raw constructor argument.
        self.linear2 = o3.Linear(
            irreps_in=self.equivariant_nonlin.irreps_out,
            irreps_out=self.irreps_in,
        )

    def create_nonlinearity(self, nonlinearity_type, irreps_mid,
                            nonlinearity_scalars, nonlinearity_gates):
        """
        Create nonlinearity module.

        Args:
            nonlinearity_type (str): "gate" builds a ``Gate``; any other value
                builds a ``NormActivation``.
            irreps_mid (o3.Irreps): Irreps the nonlinearity acts on.
            nonlinearity_scalars (dict): Scalar nonlinearities keyed by parity
                (1 / -1).
            nonlinearity_gates (dict): Gate nonlinearities keyed by parity
                (1 / -1). Only used for the "gate" type.

        Returns:
            Gate or NormActivation: The equivariant nonlinearity module.
        """
        if nonlinearity_type == "gate":
            irreps_scalars, irreps_gates, irreps_gated, act_scalars, act_gates = irreps2gate(
                irreps_mid, nonlinearity_scalars, nonlinearity_gates
            )
            return Gate(
                irreps_scalars=irreps_scalars,
                act_scalars=act_scalars,
                irreps_gates=irreps_gates,
                act_gates=act_gates,
                irreps_gated=irreps_gated,
            )

        # "norm": only the even-parity scalar nonlinearity is used, looked up
        # in `acts`.
        return NormActivation(
            irreps_in=irreps_mid,
            scalar_nonlinearity=acts[nonlinearity_scalars[1]],
            normalize=True,
            epsilon=1e-8,
            bias=False,
        )

    def forward(self, x):
        """
        Forward pass of the residual block.

        Args:
            x (torch.Tensor): Input tensor with shape matching `irreps_in`.

        Returns:
            torch.Tensor: Output tensor with shape matching `irreps_in`.
        """
        # Store old input for resnet connection if applicable
        old_x = x

        # Apply first linear transformation
        x = self.linear1(x)

        # Apply nonlinearity
        x = self.equivariant_nonlin(x)

        # Apply second linear transformation
        x = self.linear2(x)

        # Apply residual connection if resnet is enabled
        if self.resnet:
            x = old_x + x

        return x