Source code for hamgnn.models.hamgnn_conv

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""E3-equivariant message-passing backbone (:class:`HamGNNConvE3`) for HamGNN.

Stacks embedding, interaction, and atomwise blocks with NequIP-compatible
atomic data and radial/spherical encodings.
"""

import torch
from torch.utils.checkpoint import checkpoint
from e3nn import o3
from easydict import EasyDict
from typing import Optional, Dict

from .base_model import BaseModel

from ..nn.convolution import ConvBlockE3
from ..nn.embeddings import PairInteractionEmbeddingBlock, RadialBasisEdgeEncoding
from ..nn.interaction_blocks import CorrProductBlock, PairInteractionBlock
from ..toolbox.nequip.data import AtomicDataDict
from ..toolbox.nequip.nn import AtomwiseLinear
from ..toolbox.nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    Embedding_block_q
)
from ..utils.basis_functions import (
    BernsteinRadialBasisFunctions,
    BesselBasis,
    ExponentialBernsteinRadialBasisFunctions,
    ExponentialGaussianRadialBasisFunctions,
    GaussianRadialBasisFunctions,
    GaussianSmearing
)
from ..utils.cutoff_functions import CosineCutoff, cuttoff_envelope
from ..utils.math_utils import upgrade_tensor_precision


class LayerCheckpointModule(torch.nn.Module):
    """Bundle the blocks of one interaction layer for gradient checkpointing.

    The wrapped blocks are invoked for their side effects on the dictionary
    handed to them: their return values are ignored and the updated node and
    edge features are read back from that dictionary afterwards.  Running
    them directly on the caller's graph under ``checkpoint()`` would let
    those writes leak into state that the recomputation in the backward pass
    starts from, so :meth:`forward` operates on cloned feature tensors stored
    in a private working dictionary instead.

    Args:
        conv: Convolution block, always applied first.
        corr: Optional correlation-product block; may be ``None``.
        pair: Pair-interaction block, always applied last.
        use_corr_prod: Whether ``corr`` should be applied (it is skipped
            when this is false or when ``corr`` is ``None``).
    """

    def __init__(
        self,
        conv: torch.nn.Module,
        corr: Optional[torch.nn.Module],
        pair: torch.nn.Module,
        use_corr_prod: bool,
    ):
        super().__init__()
        self.conv = conv
        self.corr = corr
        self.pair = pair
        self.use_corr_prod = use_corr_prod

    def forward(
        self,
        node_feats: torch.Tensor,
        edge_feats: torch.Tensor,
        graph: Dict[str, torch.Tensor],
    ):
        """Apply conv -> (corr) -> pair on a private copy of the graph.

        Args:
            node_feats: Current node features; cloned before use.
            edge_feats: Current edge features; cloned before use.
            graph: Either an object exposing ``to_dict()`` or a plain
                mapping.  It is only read, never written.

        Returns:
            Tuple ``(node_features, edge_features)`` taken from the working
            dictionary after all blocks have run.
        """
        # The original code required ``graph.to_dict()`` even though the
        # parameter is annotated as a dict; accept both so that a plain
        # mapping no longer raises AttributeError.
        if hasattr(graph, "to_dict"):
            base = graph.to_dict()
        else:
            base = dict(graph)

        # Clone so the blocks never touch the tensors the caller passed in.
        work = {
            **base,
            AtomicDataDict.NODE_FEATURES_KEY: node_feats.clone(),
            AtomicDataDict.EDGE_FEATURES_KEY: edge_feats.clone(),
        }

        self.conv(work)
        if self.use_corr_prod and self.corr is not None:
            self.corr(work)
        self.pair(work)

        return (
            work[AtomicDataDict.NODE_FEATURES_KEY],
            work[AtomicDataDict.EDGE_FEATURES_KEY],
        )
class HamGNNConvE3(BaseModel):
    """E3-equivariant message-passing backbone.

    The model embeds atoms and edges, then applies ``num_layers`` rounds of
    convolution, optional correlation product, and pair interaction.  All
    hyper-parameters are read from ``config.HamGNN_pre``.
    """

    def __init__(self, config):
        """Build all sub-modules from ``config.HamGNN_pre``.

        Required fields: ``radius_type``, ``num_types``, ``irreps_edge_sh``,
        ``edge_sh_normalization``, ``edge_sh_normalize``,
        ``build_internal_graph``, ``cutoff``, ``rbf_func``, ``num_radial``,
        ``num_layers``, ``irreps_node_features``, ``use_kan``,
        ``radial_MLP``, ``correlation``, ``num_hidden_features``.

        Optional fields (default): ``radius_scale`` (1.0, written back to
        the config when absent), ``use_corr_prod`` (False),
        ``use_gradient_checkpointing`` (False), ``legacy_edge_update``
        (False), ``lite_mode`` (False), ``apply_charge_doping`` (False),
        ``num_charge_attr_feas`` (8, only read with charge doping).

        Raises:
            ValueError: If ``radius_scale`` is given and is not >= 1.0, or
                if ``rbf_func`` names an unsupported radial basis.
        """
        hparams = config.HamGNN_pre

        if 'radius_scale' not in hparams:
            hparams.radius_scale = 1.0
        elif not (hparams.radius_scale >= 1.0):
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; 1.0 is accepted because it is also the default.
            raise ValueError(
                "The radius scaling factor must be at least 1.0.")

        super().__init__(radius_type=hparams.radius_type,
                         radius_scale=hparams.radius_scale)

        # Configuration settings
        self.num_types = hparams.num_types  # Number of atomic species
        self.set_features = True  # Use the atomic encoding as node features

        # Irreps for edge spherical harmonics
        self.irreps_edge_sh = o3.Irreps(hparams.irreps_edge_sh)
        self.edge_sh_normalization = hparams.edge_sh_normalization
        self.edge_sh_normalize = hparams.edge_sh_normalize
        self.build_internal_graph = hparams.build_internal_graph

        if 'use_corr_prod' in hparams:
            self.use_corr_prod = hparams.use_corr_prod
        else:
            self.use_corr_prod = False
        self.use_gradient_checkpointing = getattr(
            hparams, 'use_gradient_checkpointing', False)
        # Legacy edge update
        self.legacy_edge_update = getattr(
            hparams, 'legacy_edge_update', False)
        # Set product mode
        self.lite_mode = getattr(hparams, 'lite_mode', False)

        # Radial basis function
        self.cutoff = hparams.cutoff
        self.rbf_func = hparams.rbf_func.lower()
        self.num_radial = hparams.num_radial
        self.radial_basis_functions = self._make_radial_basis_functions(
            self.rbf_func, self.num_radial, self.cutoff)

        self.num_layers = hparams.num_layers  # Number of interaction layers
        # Irreps for node features
        self.irreps_node_features = o3.Irreps(hparams.irreps_node_features)

        # Atomic embedding
        self.apply_charge_doping = getattr(
            hparams, 'apply_charge_doping', False)
        if self.apply_charge_doping:
            num_charge_attr_feas = getattr(
                hparams, 'num_charge_attr_feas', 8)
            self.atomic_embedding = Embedding_block_q(
                num_types=self.num_types,
                num_charge_attr_feas=num_charge_attr_feas,
                apply_charge_doping=True,
                set_features=self.set_features)
        else:
            self.atomic_embedding = OneHotAtomEncoding(
                num_types=self.num_types, set_features=self.set_features)

        # Spherical harmonics for edges
        self.spharm_edges = SphericalHarmonicEdgeAttrs(
            irreps_edge_sh=self.irreps_edge_sh,
            edge_sh_normalization=self.edge_sh_normalization,
            edge_sh_normalize=self.edge_sh_normalize)

        # Radial basis for edges
        self.cutoff_func = CosineCutoff(self.cutoff)
        self.radial_basis = RadialBasisEdgeEncoding(
            basis=self.radial_basis_functions, cutoff=self.cutoff_func)

        # Irreps shared by every block below; looked up once.
        node_attrs_irreps = self.atomic_embedding.irreps_out['node_attrs']
        edge_attrs_irreps = self.spharm_edges.irreps_out[
            AtomicDataDict.EDGE_ATTRS_KEY]
        edge_embed_irreps = self.radial_basis.irreps_out[
            AtomicDataDict.EDGE_EMBEDDING_KEY]

        # Edge features embedding
        use_kan = hparams.use_kan
        self.radial_MLP = hparams.radial_MLP
        self.pair_embedding = PairInteractionEmbeddingBlock(
            irreps_node_feats=node_attrs_irreps,
            irreps_edge_attrs=edge_attrs_irreps,
            irreps_edge_embed=edge_embed_irreps,
            irreps_edge_feats=self.irreps_node_features,
            irreps_node_attrs=node_attrs_irreps,
            use_kan=use_kan,
            radial_MLP=self.radial_MLP,
            lite_mode=self.lite_mode)

        # Chemical embedding
        self.chemical_embedding = AtomwiseLinear(
            irreps_in={AtomicDataDict.NODE_FEATURES_KEY: node_attrs_irreps},
            irreps_out=self.irreps_node_features)

        # Define the interaction layers
        correlation = hparams.correlation
        num_hidden_features = hparams.num_hidden_features

        self.convolutions = torch.nn.ModuleList()
        if self.use_corr_prod:
            self.corr_products = torch.nn.ModuleList()
        self.pair_interactions = torch.nn.ModuleList()

        for i in range(self.num_layers):
            self.convolutions.append(ConvBlockE3(
                irreps_in=self.irreps_node_features,
                irreps_out=self.irreps_node_features,
                irreps_node_attrs=node_attrs_irreps,
                irreps_edge_attrs=edge_attrs_irreps,
                irreps_edge_embed=edge_embed_irreps,
                radial_MLP=self.radial_MLP,
                use_skip_connections=True,
                use_kan=use_kan,
                lite_mode=self.lite_mode))

            if self.use_corr_prod:
                self.corr_products.append(CorrProductBlock(
                    irreps_node_feats=self.irreps_node_features,
                    num_hidden_features=num_hidden_features,
                    correlation=correlation,
                    num_elements=self.num_types,
                    use_skip_connections=True))

            # With the legacy edge update the first layer has no skip
            # connection; otherwise every layer has one.
            pair_skip = (i > 0) if self.legacy_edge_update else True
            self.pair_interactions.append(PairInteractionBlock(
                irreps_node_feats=self.irreps_node_features,
                irreps_node_attrs=node_attrs_irreps,
                irreps_edge_attrs=edge_attrs_irreps,
                irreps_edge_embed=edge_embed_irreps,
                irreps_edge_feats=self.irreps_node_features,
                use_skip_connections=pair_skip,
                legacy_edge_update=self.legacy_edge_update,
                use_kan=use_kan,
                radial_MLP=self.radial_MLP,
                lite_mode=self.lite_mode))

        # Gradient checkpointing layer wrappers.
        # NOTE: the wrappers hold the very same block instances as the
        # lists above; they are created only when checkpointing is enabled.
        if self.use_gradient_checkpointing:
            self.layer_checkpoints = torch.nn.ModuleList()
            for i in range(self.num_layers):
                self.layer_checkpoints.append(LayerCheckpointModule(
                    conv=self.convolutions[i],
                    corr=self.corr_products[i] if self.use_corr_prod else None,
                    pair=self.pair_interactions[i],
                    use_corr_prod=self.use_corr_prod,
                ))

    @staticmethod
    def _make_radial_basis_functions(name, num_radial, cutoff):
        """Instantiate the radial basis selected by the lower-cased ``name``."""
        if name == 'gaussian':
            return GaussianSmearing(
                start=0.0, stop=cutoff, num_gaussians=num_radial,
                cutoff_func=None)
        if name == 'bessel':
            return BesselBasis(
                cutoff=cutoff, n_rbf=num_radial, cutoff_func=None)
        if name == 'exp-gaussian':
            return ExponentialGaussianRadialBasisFunctions(num_radial, cutoff)
        if name == 'exp-bernstein':
            return ExponentialBernsteinRadialBasisFunctions(num_radial, cutoff)
        if name == 'bernstein':
            return BernsteinRadialBasisFunctions(num_radial, cutoff)
        raise ValueError(f'Unsupported radial basis function: {name}')

    def forward(self, data):
        """Encode ``data`` and run every interaction layer.

        Args:
            data: Input graph data.  When ``build_internal_graph`` is set it
                is first passed through ``self.generate_graph``; otherwise
                it is used directly.  The graph object is updated in place
                by the embedding blocks and the layers.

        Returns:
            EasyDict with ``node_attr`` (final node features) and
            ``edge_attr`` (final edge features, restricted to
            ``graph.matching_edges`` when the graph was built internally).
        """
        if torch.get_default_dtype() == torch.float64:
            upgrade_tensor_precision(data)

        graph = self.generate_graph(data) if self.build_internal_graph else data

        # Each encoder writes its result into ``graph``.
        self.atomic_embedding(graph)
        self.spharm_edges(graph)
        self.radial_basis(graph)
        self.pair_embedding(graph)
        self.chemical_embedding(graph)

        node_key = AtomicDataDict.NODE_FEATURES_KEY
        edge_key = AtomicDataDict.EDGE_FEATURES_KEY

        # Interaction layers
        for i in range(self.num_layers):
            if self.use_gradient_checkpointing:
                # Non-reentrant checkpointing is the variant recommended by
                # PyTorch.  The reentrant one ignores tensors inside nested
                # containers (``graph`` here) for autograd purposes and
                # needs an input with requires_grad=True.
                new_node, new_edge = checkpoint(  # pyright: ignore[reportGeneralTypeIssues]
                    self.layer_checkpoints[i],
                    graph[node_key],
                    graph[edge_key],
                    graph,
                    use_reentrant=False)
                graph[node_key] = new_node
                graph[edge_key] = new_edge
            else:
                self.convolutions[i](graph)
                if self.use_corr_prod:
                    self.corr_products[i](graph)
                self.pair_interactions[i](graph)

        graph_representation = EasyDict()
        graph_representation['node_attr'] = graph[node_key]
        if self.build_internal_graph:
            graph_representation['edge_attr'] = graph[edge_key][graph.matching_edges]
        else:
            graph_representation['edge_attr'] = graph[edge_key]
        return graph_representation