Source code for hamgnn.models.hamgnn_transformer

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""Transformer-style E3-equivariant representation (:class:`HamGNNTransformer`) for HamGNN.

Uses attention blocks alongside pair interactions and spherical edge attributes.
"""

import torch
from e3nn import o3
from easydict import EasyDict

from .base_model import BaseModel

from ..nn.attention import AttentionBlockE3
from ..nn.embeddings import PairInteractionEmbeddingBlock, RadialBasisEdgeEncoding
from ..nn.interaction_blocks import CorrProductBlock, PairInteractionBlock
from ..toolbox.nequip.data import AtomicDataDict
from ..toolbox.nequip.nn import AtomwiseLinear
from ..toolbox.nequip.nn.embedding import (
    OneHotAtomEncoding,
    SphericalHarmonicEdgeAttrs,
    Embedding_block_q
)

from ..utils.basis_functions import (
    BernsteinRadialBasisFunctions,
    BesselBasis,
    ExponentialBernsteinRadialBasisFunctions,
    ExponentialGaussianRadialBasisFunctions,
    GaussianRadialBasisFunctions,
    GaussianSmearing
)
from ..utils.cutoff_functions import CosineCutoff

class HamGNNTransformer(BaseModel):
    """Transformer-style E3-equivariant graph representation for HamGNN.

    The input graph is first embedded:

    * one-hot (or charge-doped) atom encoding;
    * spherical harmonics of edge vectors;
    * radial basis of edge lengths;
    * a pair-interaction edge embedding;
    * a linear chemical embedding of the nodes.

    It is then processed by ``num_layers`` layers. Each layer applies, in
    order, an :class:`AttentionBlockE3`, a :class:`CorrProductBlock` and a
    :class:`PairInteractionBlock`. Node and edge features are both
    configured with the irreps given by
    ``config.HamGNN_pre.irreps_node_features``.

    Args:
        config: Configuration object whose ``HamGNN_pre`` section holds the
            hyper-parameters read in ``__init__`` (presumably an ``EasyDict``).
    """

    def __init__(self, config):
        """Build all sub-modules from ``config.HamGNN_pre``.

        Required ``HamGNN_pre`` keys:

        * ``radius_type``, ``num_types``, ``irreps_edge_sh``;
        * ``edge_sh_normalization``, ``edge_sh_normalize``;
        * ``build_internal_graph``, ``cutoff``, ``rbf_func``, ``num_radial``;
        * ``num_layers``, ``irreps_node_features``, ``use_kan``;
        * ``radial_MLP``, ``num_heads``, ``correlation``;
        * ``num_hidden_features``.

        Optional keys:

        * ``radius_scale`` (default 1.0);
        * ``apply_charge_doping`` (default False);
        * ``num_charge_attr_feas`` (default 8);
        * ``legacy_edge_update`` (default False).

        Note: if ``radius_scale`` is absent it is written back into
        ``config.HamGNN_pre`` as ``1.0``, i.e. the caller's config is mutated.

        Raises:
            ValueError: If an explicit ``radius_scale`` is not greater than
                1.0, or if ``rbf_func`` names an unsupported radial basis.
        """
        hparams = config.HamGNN_pre

        if 'radius_scale' not in hparams:
            hparams.radius_scale = 1.0
        elif not (hparams.radius_scale > 1.0):
            # Raised explicitly rather than via ``assert`` so the check survives
            # ``python -O``; ``not (x > 1.0)`` also rejects NaN.
            # NOTE(review): an explicit 1.0 is rejected although 1.0 is the
            # default when the key is omitted -- confirm ``>`` is intended.
            raise ValueError("The radius scaling factor must be greater than 1.0.")

        super().__init__(radius_type=hparams.radius_type,
                         radius_scale=hparams.radius_scale)

        # --- Configuration settings ---------------------------------------
        self.num_types = hparams.num_types  # number of atomic species
        self.set_features = True  # use the one-hot encoding as initial node features
        self.irreps_edge_sh = o3.Irreps(hparams.irreps_edge_sh)  # edge spherical-harmonic irreps
        self.edge_sh_normalization = hparams.edge_sh_normalization
        self.edge_sh_normalize = hparams.edge_sh_normalize
        self.build_internal_graph = hparams.build_internal_graph

        # --- Radial basis functions ---------------------------------------
        self.cutoff = hparams.cutoff
        self.rbf_func = hparams.rbf_func.lower()
        self.num_radial = hparams.num_radial
        self.radial_basis_functions = self._build_radial_basis_functions()

        self.num_layers = hparams.num_layers  # number of attention/interaction layers
        self.irreps_node_features = o3.Irreps(hparams.irreps_node_features)

        # --- Atomic embedding ---------------------------------------------
        self.apply_charge_doping = getattr(hparams, 'apply_charge_doping', False)
        if self.apply_charge_doping:
            num_charge_attr_feas = getattr(hparams, 'num_charge_attr_feas', 8)
            self.atomic_embedding = Embedding_block_q(
                num_types=self.num_types,
                num_charge_attr_feas=num_charge_attr_feas,
                apply_charge_doping=True,
                set_features=self.set_features)
        else:
            self.atomic_embedding = OneHotAtomEncoding(
                num_types=self.num_types, set_features=self.set_features)
        irreps_node_attrs = self.atomic_embedding.irreps_out['node_attrs']

        # --- Spherical harmonics of edge vectors --------------------------
        self.spharm_edges = SphericalHarmonicEdgeAttrs(
            irreps_edge_sh=self.irreps_edge_sh,
            edge_sh_normalization=self.edge_sh_normalization,
            edge_sh_normalize=self.edge_sh_normalize)
        irreps_edge_attrs = self.spharm_edges.irreps_out[AtomicDataDict.EDGE_ATTRS_KEY]

        # --- Radial encoding of edge lengths ------------------------------
        self.cutoff_func = CosineCutoff(self.cutoff)
        self.radial_basis = RadialBasisEdgeEncoding(
            basis=self.radial_basis_functions, cutoff=self.cutoff_func)
        irreps_edge_embed = self.radial_basis.irreps_out[AtomicDataDict.EDGE_EMBEDDING_KEY]

        # --- Edge feature embedding ---------------------------------------
        use_kan = hparams.use_kan
        self.radial_MLP = hparams.radial_MLP
        self.pair_embedding = PairInteractionEmbeddingBlock(
            irreps_node_feats=irreps_node_attrs,
            irreps_edge_attrs=irreps_edge_attrs,
            irreps_edge_embed=irreps_edge_embed,
            irreps_edge_feats=self.irreps_node_features,
            irreps_node_attrs=irreps_node_attrs,
            use_kan=use_kan,
            radial_MLP=self.radial_MLP)

        # --- Chemical embedding -------------------------------------------
        self.chemical_embedding = AtomwiseLinear(
            irreps_in={AtomicDataDict.NODE_FEATURES_KEY: irreps_node_attrs},
            irreps_out=self.irreps_node_features)

        # --- Stacked layers: attention -> correlation product -> pair interaction
        self.num_heads = hparams.num_heads
        correlation = hparams.correlation
        num_hidden_features = hparams.num_hidden_features
        legacy_edge_update = getattr(hparams, 'legacy_edge_update', False)

        self.orb_transformers = torch.nn.ModuleList()
        self.corr_products = torch.nn.ModuleList()
        self.pair_interactions = torch.nn.ModuleList()
        # Modules are built layer by layer (not list by list) to keep the
        # original construction order; this matters for seeded reproducibility
        # if these blocks draw random initial weights.
        for _ in range(self.num_layers):
            self.orb_transformers.append(AttentionBlockE3(
                irreps_in=self.irreps_node_features,
                irreps_node_attrs=irreps_node_attrs,
                irreps_out=self.irreps_node_features,
                irreps_edge_feats=self.irreps_node_features,
                irreps_edge_attrs=irreps_edge_attrs,
                irreps_edge_embed=irreps_edge_embed,
                num_heads=self.num_heads,
                max_radius=self.cutoff,
                radial_MLP=self.radial_MLP,
                use_skip_connections=True,
                use_kan=use_kan))
            self.corr_products.append(CorrProductBlock(
                irreps_node_feats=self.irreps_node_features,
                num_hidden_features=num_hidden_features,
                correlation=correlation,
                num_elements=self.num_types,
                use_skip_connections=True))
            self.pair_interactions.append(PairInteractionBlock(
                irreps_node_feats=self.irreps_node_features,
                irreps_node_attrs=irreps_node_attrs,
                irreps_edge_attrs=irreps_edge_attrs,
                irreps_edge_embed=irreps_edge_embed,
                irreps_edge_feats=self.irreps_node_features,
                use_skip_connections=True,
                legacy_edge_update=legacy_edge_update,
                use_kan=use_kan,
                radial_MLP=self.radial_MLP))

    def _build_radial_basis_functions(self):
        """Instantiate the radial basis selected by ``self.rbf_func``.

        Reads ``self.rbf_func`` (already lower-cased), ``self.cutoff`` and
        ``self.num_radial``; only the selected basis is constructed.

        Raises:
            ValueError: If ``self.rbf_func`` is not one of ``'gaussian'``,
                ``'bessel'``, ``'exp-gaussian'``, ``'exp-bernstein'`` or
                ``'bernstein'``.
        """
        builders = {
            'gaussian': lambda: GaussianSmearing(
                start=0.0, stop=self.cutoff, num_gaussians=self.num_radial,
                cutoff_func=None),
            'bessel': lambda: BesselBasis(
                cutoff=self.cutoff, n_rbf=self.num_radial, cutoff_func=None),
            'exp-gaussian': lambda: ExponentialGaussianRadialBasisFunctions(
                self.num_radial, self.cutoff),
            'exp-bernstein': lambda: ExponentialBernsteinRadialBasisFunctions(
                self.num_radial, self.cutoff),
            'bernstein': lambda: BernsteinRadialBasisFunctions(
                self.num_radial, self.cutoff),
        }
        try:
            build = builders[self.rbf_func]
        except KeyError:
            raise ValueError(
                f'Unsupported radial basis function: {self.rbf_func}') from None
        return build()

    def forward(self, data):
        """Embed the graph, run all layers, and return node/edge features.

        Args:
            data: Input graph mapping keyed by ``AtomicDataDict`` keys. When
                ``build_internal_graph`` is set, ``self.generate_graph(data)``
                produces the graph that is processed; otherwise ``data`` itself
                is processed.

        Returns:
            EasyDict with:

            * ``'node_attr'``: the final node features.
            * ``'edge_attr'``: the final edge features. When an internal
              graph was built, only the rows selected by
              ``graph.matching_edges`` are returned (presumably mapping
              internal edges back to the input edges; confirm against
              ``generate_graph``).
        """
        graph = self.generate_graph(data) if self.build_internal_graph else data

        # Each sub-module is called for its effect on ``graph``; return
        # values are discarded.
        self.atomic_embedding(graph)
        self.spharm_edges(graph)
        self.radial_basis(graph)
        self.pair_embedding(graph)
        self.chemical_embedding(graph)

        # Orbital convolution
        for attention, corr_product, pair_interaction in zip(
                self.orb_transformers, self.corr_products, self.pair_interactions):
            attention(graph)
            corr_product(graph)
            pair_interaction(graph)

        node_features = graph[AtomicDataDict.NODE_FEATURES_KEY]
        edge_features = graph[AtomicDataDict.EDGE_FEATURES_KEY]
        if self.build_internal_graph:
            edge_features = edge_features[graph.matching_edges]

        graph_representation = EasyDict()
        graph_representation['node_attr'] = node_features
        graph_representation['edge_attr'] = edge_features
        return graph_representation