# Source code for hamgnn.nn.convolution

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""E3-equivariant convolution block (:class:`ConvBlockE3`) with tensor products and skips.

Wraps message packing, radial networks, and optional KAN / gate nonlinearities for
graph convolution on atomic data.
"""

from typing import Callable, Dict, List, Optional, Tuple

import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from torch import nn
from torch_scatter import scatter

from .interaction_blocks import ResidualBlock
from .message_passing import MessagePackBlock
from ..toolbox.nequip.data import AtomicDataDict

@compile_mode("script")
class ConvBlockE3(nn.Module):
    """
    An equivariant convolutional block for processing node features using
    tensor products with optional skip connections.

    The block packs per-edge messages with a :class:`MessagePackBlock`,
    aggregates them onto the receiving nodes, passes the result through a
    :class:`ResidualBlock`, and optionally adds a linear projection of the
    input node features.

    Parameters:
    - irreps_in (o3.Irreps): Input irreducible representations.
    - irreps_out (o3.Irreps): Output irreducible representations.
    - irreps_node_attrs (o3.Irreps): Node attribute irreducible
      representations. Stored on the module; not passed to any sub-module
      built in this block.
    - irreps_edge_attrs (o3.Irreps): Edge attribute irreducible representations.
    - irreps_edge_embed (o3.Irreps): Edge embedding irreducible representations.
    - radial_MLP (Optional[List[int]]): MLP architecture for radial embeddings.
      Defaults to [64, 64, 64].
    - use_skip_connections (bool): Whether to use skip connections.
      Defaults to True.
    - use_kan (bool): Forwarded to ``MessagePackBlock`` as ``use_kan``.
      Defaults to False.
    - nonlinearity_type (str): Type of nonlinearity ("gate" or "norm").
      Defaults to "gate". Only validated here.
    - nonlinearity_scalars (Optional[Dict[str, str]]): Nonlinearity names for
      scalar channels, keyed by parity ("e" for even, "o" for odd).
      Defaults to {"e": "ssp", "o": "tanh"}.
    - nonlinearity_gates (Optional[Dict[str, str]]): Nonlinearity names for
      gate channels, keyed by parity ("e" for even, "o" for odd).
      Defaults to {"e": "ssp", "o": "abs"}.
    - lite_mode (bool): Forwarded to ``MessagePackBlock`` as ``lite_mode``.
      Defaults to False.

    Note: both nonlinearity mappings must contain the keys "e" and "o"
    (a ``KeyError`` is raised otherwise), but the resulting parity-keyed
    mappings are not consumed by any sub-module constructed in this block.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        irreps_node_attrs: o3.Irreps,
        irreps_edge_attrs: o3.Irreps,
        irreps_edge_embed: o3.Irreps,
        radial_MLP: Optional[list] = None,
        use_skip_connections: bool = True,
        use_kan: bool = False,
        nonlinearity_type: str = "gate",
        nonlinearity_scalars: Optional[Dict[str, str]] = None,
        nonlinearity_gates: Optional[Dict[str, str]] = None,
        lite_mode: bool = False,
    ) -> None:
        super().__init__()

        # Resolve defaults here rather than in the signature so that no
        # mutable dict object is shared between instances.
        if nonlinearity_scalars is None:
            nonlinearity_scalars = {"e": "ssp", "o": "tanh"}
        if nonlinearity_gates is None:
            nonlinearity_gates = {"e": "ssp", "o": "abs"}

        self.radial_MLP = radial_MLP or [64, 64, 64]
        self.use_kan = use_kan
        self.use_skip_connections = use_skip_connections
        self.lite_mode = lite_mode

        assert nonlinearity_type in ("gate", "norm"), "Invalid nonlinearity type."

        # Convert the "e"/"o" parity labels to +1/-1 keys. The lookups are
        # kept because they reject mappings lacking either parity key.
        # NOTE(review): neither mapping is used below -- confirm whether a
        # gate/norm nonlinearity was meant to be built from them.
        scalar_nonlinearities = {
            1: nonlinearity_scalars["e"],
            -1: nonlinearity_scalars["o"],
        }
        gate_nonlinearities = {
            1: nonlinearity_gates["e"],
            -1: nonlinearity_gates["o"],
        }

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attrs = o3.Irreps(irreps_node_attrs)
        self.irreps_edge_attrs = o3.Irreps(irreps_edge_attrs)
        self.irreps_edge_embed = o3.Irreps(irreps_edge_embed)

        # Residual block for processing features
        self.residual = ResidualBlock(self.irreps_in, self.irreps_out)

        # Convolution layers
        self.conv_tp = MessagePackBlock(
            irreps_node_feats=self.irreps_in,
            irreps_edge_feats=self.irreps_in,
            irreps_local_env_edge=self.irreps_edge_attrs,
            irreps_out=self.irreps_out,
            irreps_edge_scalars=self.irreps_edge_embed,
            radial_MLP=self.radial_MLP,
            use_kan=self.use_kan,
            lite_mode=self.lite_mode,
        )

        # Skip connection layer (only created when skip connections are on)
        if self.use_skip_connections:
            self.skip_linear = self.create_linear(self.irreps_in, self.irreps_out)

    def create_linear(self, irreps_in, irreps_out=None):
        """
        Create a linear layer.

        Parameters:
        - irreps_in (o3.Irreps): Input irreps for the linear layer.
        - irreps_out (o3.Irreps, optional): Output irreps for the linear layer.
          When omitted (or falsy), ``irreps_in`` is used for the output too.

        Returns:
        - o3.Linear: A linear transformation layer with internal, shared weights.
        """
        return o3.Linear(
            irreps_in,
            irreps_out or irreps_in,
            internal_weights=True,
            shared_weights=True,
        )

    def forward(self, data: dict) -> torch.Tensor:
        """
        Forward pass of the convolutional block.

        Reads the edge index, node features, edge features, edge attributes
        and edge embedding from ``data``. The updated node features are
        written back to ``data[AtomicDataDict.NODE_FEATURES_KEY]`` in
        addition to being returned.

        Parameters:
        - data (dict): Dictionary containing graph data.

        Returns:
        - torch.Tensor: Updated node features.
        """
        sender, receiver = data[AtomicDataDict.EDGE_INDEX_KEY]
        node_features = data[AtomicDataDict.NODE_FEATURES_KEY]
        edge_features = data[AtomicDataDict.EDGE_FEATURES_KEY]
        edge_embedding = data[AtomicDataDict.EDGE_EMBEDDING_KEY]
        edge_attributes = data[AtomicDataDict.EDGE_ATTRS_KEY]
        num_nodes = node_features.shape[0]

        # Skip connection: linear projection of the *input* node features
        skip_connection = (
            self.skip_linear(node_features) if self.use_skip_connections else None
        )

        # Messages, one per edge
        messages = self.conv_tp(
            node_features[sender],
            node_features[receiver],
            edge_features,
            edge_attributes,
            edge_embedding,
        )

        # Aggregate messages onto the receiving nodes.
        # Note: Use num_nodes instead of receiver.max().item()+1 to avoid
        # GPU-CPU synchronization overhead from .item() calls
        aggregated_messages = scatter(
            src=messages,
            index=receiver,
            dim=0,
            dim_size=num_nodes,
        )

        # Apply residual block
        output_features = self.residual(aggregated_messages)

        # Apply skip connection if used. The add is out-of-place so the
        # residual block's output tensor is never mutated after the fact.
        if skip_connection is not None:
            output_features = output_features + skip_connection

        data[AtomicDataDict.NODE_FEATURES_KEY] = output_features
        return output_features