# Source code for hamgnn.nn.embeddings

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""Radial and pair embeddings for edges and nodes in equivariant HamGNN stacks.

Includes radial-basis edge encodings, one-hot and spherical-harmonic style embedders,
and pair-interaction feature blocks used with NequIP-style :class:`GraphModuleMixin`.
"""

from typing import Callable, Dict, List, Optional, Tuple

import torch
import numpy as np
from e3nn import o3
from e3nn.nn import FullyConnectedNet
from e3nn.util.jit import compile_mode
from torch import nn

from ..toolbox.efficient_kan import KAN
from ..toolbox.nequip.data import AtomicDataDict
from ..toolbox.nequip.nn import GraphModuleMixin
from ..utils.macro import GRID_RANGE, GRID_SIZE
from .electron_configurations import electron_configurations
from .tensor_products import TensorProductWithMemoryOptimizationWithWeight

@compile_mode('script')
class RadialBasisEdgeEncoding(GraphModuleMixin, torch.nn.Module):
    """Radial-basis encoding of edge lengths.

    Besides the embedding itself, the forward pass records the normalized
    edge vectors and the edge lengths in the data dictionary.

    Attributes:
        basis: Radial basis module evaluated on the edge lengths.
        cutoff: Optional callable; when set, its value scales every basis channel.
        out_field (str): The key for storing the encoded edge features.
    """

    def __init__(
        self,
        basis=None,
        cutoff=None,
        out_field: str = AtomicDataDict.EDGE_EMBEDDING_KEY,
        irreps_in=None,
    ):
        """Store the basis/cutoff and declare the output irreps.

        :param basis: The radial basis module. Its class name selects where the
            number of basis functions is read from (``freqs`` for
            ``BesselBasis``, ``offset`` for ``GaussianSmearing``,
            ``num_basis_functions`` for the remaining supported types).
        :param cutoff: Optional cutoff callable applied to the edge lengths;
            ``None`` leaves the basis values unscaled.
        :param out_field: The output field key for encoded edges.
        :param irreps_in: Input irreducible representations.
        :raises NotImplementedError: If the basis class name is not supported
            (this includes the default ``basis=None``).
        """
        super().__init__()
        self.basis = basis
        self.cutoff = cutoff
        self.out_field = out_field

        kind = type(basis).__name__.rpartition(".")[-1]
        num_attr_kinds = (
            'ExponentialGaussianRadialBasisFunctions',
            'ExponentialBernsteinRadialBasisFunctions',
            'GaussianRadialBasisFunctions',
            'BernsteinRadialBasisFunctions',
        )
        if kind == 'BesselBasis':
            n_radial = basis.freqs.size(0)
        elif kind == 'GaussianSmearing':
            n_radial = basis.offset.size(0)
        elif kind in num_attr_kinds:
            n_radial = basis.num_basis_functions
        else:
            raise NotImplementedError(f"Basis type {kind} is not supported.")

        # The embedding is ``n_radial`` copies of the scalar irrep (l=0, parity +1).
        irreps_out = {self.out_field: o3.Irreps([(n_radial, (0, 1))])}
        self._init_irreps(irreps_in=irreps_in, irreps_out=irreps_out)

    def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
        """Encode the edges of ``data`` in place.

        :param data: Graph data providing ``edge_index``, ``nbr_shift`` and
            ``pos`` as attributes and accepting item assignment.
        :return: ``data`` with the normalized edge vectors, the edge lengths
            and the radial embedding (under ``out_field``) written into it.
        """
        idx_j, idx_i = data.edge_index
        positions = data.pos
        # Displacement from endpoint j to endpoint i, offset by nbr_shift
        # (presumably the periodic-image shift of i -- confirm with the data pipeline).
        displacement = (positions[idx_i] + data.nbr_shift) - positions[idx_j]
        lengths = displacement.norm(dim=-1)

        # NOTE: a zero-length edge would yield a NaN direction (0 / 0) here.
        data[AtomicDataDict.EDGE_VECTORS_KEY] = displacement / lengths.unsqueeze(1)
        data[AtomicDataDict.EDGE_LENGTH_KEY] = lengths

        embedded = self.basis(lengths)
        if self.cutoff is not None:
            embedded = embedded * self.cutoff(lengths).unsqueeze(1)
        data[self.out_field] = embedded
        return data
@compile_mode("script")
class EdgeScalarEmbedding(nn.Module):
    """
    A layer to compute edge scalars from node attributes and edge embeddings.

    The attributes of the source node, the attributes of the destination node
    and the edge embedding are concatenated (in that order) and mapped by an
    equivariant linear layer to the edge scalars.

    Args:
        irreps_node_attrs (Irreps or str): Irreps for node attributes.
        irreps_edge_embed (Irreps or str): Irreps for edge embeddings.
        irreps_edge_scalars (Irreps or str): Irreps for edge scalars.
    """

    def __init__(self, irreps_node_attrs, irreps_edge_embed, irreps_edge_scalars):
        super().__init__()
        # Normalize to o3.Irreps so that ``+`` concatenates irreps; applied to
        # plain strings it would join the text instead (e.g. "8x0e8x0e").
        irreps_node_attrs = o3.Irreps(irreps_node_attrs)
        irreps_in = irreps_node_attrs + irreps_node_attrs + o3.Irreps(irreps_edge_embed)
        self.linear_out = o3.Linear(irreps_in, o3.Irreps(irreps_edge_scalars))

    def forward(self, node_attr_src, node_attr_dst, edge_embed):
        """
        Forward pass to compute edge scalars.

        Args:
            node_attr_src (Tensor): Source node attributes (one row per edge).
            node_attr_dst (Tensor): Destination node attributes (one row per edge).
            edge_embed (Tensor): Edge embeddings.

        Returns:
            Tensor: Computed edge scalars.
        """
        # Concatenation order must match the irreps order used in __init__.
        combined_features = torch.cat([node_attr_src, node_attr_dst, edge_embed], dim=-1)
        return self.linear_out(combined_features)
@compile_mode("script")
class LocalEnvironmentEmbedding(nn.Module):
    """
    Embeds local environments using node attributes, edge attributes and edge embeddings.

    Per edge, an :class:`EdgeScalarEmbedding` mixes the attributes of both
    endpoints with the edge embedding into edge scalars. A radial network maps
    those scalars to the per-edge weights of a tensor product between the edge
    attributes and a constant ``1x0e`` input.

    Args:
        irreps_edge_attrs (Irreps or str): Irreps for edge attributes.
        irreps_edge_embed (Irreps or str): Irreps for edge embeddings.
        irreps_node_attrs (Irreps or str): Irreps for node attributes.
        irreps_edge_scalars (Irreps or str): Irreps for edge scalars; these are
            the input of the weight generator.
        irreps_env_sh (Irreps or str): Output irreps; entry ``i`` receives the
            path coming from entry ``i`` of ``irreps_edge_attrs``.
        radial_MLP (list[int], optional): Hidden-layer widths of the weight
            generator. Defaults to ``[64, 64]``.
        use_kan (bool): Whether to use a KAN instead of a fully connected network.
    """

    def __init__(
        self,
        irreps_edge_attrs,
        irreps_edge_embed,
        irreps_node_attrs,
        irreps_edge_scalars,
        irreps_env_sh,
        radial_MLP=None,
        use_kan=False,
    ):
        super().__init__()
        # Normalize so that len()/.dim below operate on irreps rather than strings.
        irreps_edge_attrs = o3.Irreps(irreps_edge_attrs)
        irreps_edge_embed = o3.Irreps(irreps_edge_embed)
        irreps_node_attrs = o3.Irreps(irreps_node_attrs)
        irreps_edge_scalars = o3.Irreps(irreps_edge_scalars)
        irreps_env_sh = o3.Irreps(irreps_env_sh)
        # Avoid a shared mutable default list.
        if radial_MLP is None:
            radial_MLP = [64, 64]

        self.edge_scalar_layer = EdgeScalarEmbedding(irreps_node_attrs, irreps_edge_embed, irreps_edge_scalars)
        # One weighted "uvw" path per edge-attribute entry:
        # edge_attrs[i] x 1x0e -> env_sh[i].
        instructions = [(i, 0, i, "uvw", True) for i in range(len(irreps_edge_attrs))]
        self.tensor_product = o3.TensorProduct(
            irreps_edge_attrs,
            o3.Irreps('1x0e'),
            irreps_env_sh,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.weight_numel = self.tensor_product.weight_numel
        # The generator is fed the output of edge_scalar_layer (see forward), so
        # its input width is that of irreps_edge_scalars, not of irreps_edge_embed.
        input_dim = irreps_edge_scalars.dim
        self.weight_generator = self._initialize_weight_generator(input_dim, self.weight_numel, radial_MLP, use_kan)

    def _initialize_weight_generator(self, input_dim, weight_numel, radial_MLP, use_kan):
        """
        Initializes the weight generator.

        Args:
            input_dim (int): Input dimension for the generator.
            weight_numel (int): Number of elements in weights.
            radial_MLP (list[int]): Hidden-layer widths of the generator.
            use_kan (bool): Whether to use the KAN model.

        Returns:
            nn.Module: The weight generator model.
        """
        if use_kan:
            return KAN([input_dim] + radial_MLP + [weight_numel], grid_size=GRID_SIZE, grid_range=GRID_RANGE)
        return FullyConnectedNet(
            [input_dim] + radial_MLP + [weight_numel],
            torch.nn.functional.silu,
        )

    def forward(self, edge_index, node_attr, edge_attr, edge_embed):
        """
        Forward pass to compute local environment embeddings.

        Args:
            edge_index (Tensor): Indices of the edges, unpacked as (src, dst).
            node_attr (Tensor): Node attributes.
            edge_attr (Tensor): Edge attributes.
            edge_embed (Tensor): Edge embeddings.

        Returns:
            Tensor: Local environment embeddings, one row per edge.
        """
        src, dst = edge_index
        # Constant ones: the 1x0e (even scalar) second operand of the tensor product.
        unit_scalar = torch.ones_like(edge_embed[:, :1])
        edge_scalars = self.edge_scalar_layer(node_attr[src], node_attr[dst], edge_embed)
        weights = self.weight_generator(edge_scalars)
        return self.tensor_product(edge_attr, unit_scalar, weights)
@compile_mode("script")
class PairInteractionEmbeddingBlock(nn.Module):
    """
    A pair interaction block that builds edge features from node features and edge attributes.

    The node features at both ends of every edge are lifted by two separate
    linear layers and summed. The sum, the edge attributes and the edge
    embedding are then passed to a ``TensorProductWithMemoryOptimizationWithWeight``
    (the edge embedding is its scalar input).

    Parameters:
    - irreps_node_feats (o3.Irreps): Irreducible representations for node features.
    - irreps_edge_attrs (o3.Irreps): Irreducible representations for edge attributes.
    - irreps_node_attrs (o3.Irreps): Irreducible representations for node attributes (stored, not used in forward).
    - irreps_edge_embed (o3.Irreps): Irreducible representations for edge embeddings.
    - irreps_edge_feats (o3.Irreps): Irreducible representations for the output edge features.
    - use_kan (bool): Whether to use KAN for the radial MLP.
    - radial_MLP (Optional[List[int]]): Hidden widths of the radial MLP; ``None`` or an empty list selects ``[64, 64, 64]``.
    - nonlinearity_type (str): "gate" or "norm". Validated, but not otherwise used by this block.
    - nonlinearity_scalars (Optional[Dict[str, str]]): Parity-keyed ("e"/"o") scalar nonlinearities.
      Accepted for configuration compatibility; not used by this block.
    - nonlinearity_gates (Optional[Dict[str, str]]): Parity-keyed ("e"/"o") gate nonlinearities.
      Accepted for configuration compatibility; not used by this block.
    - lite_mode (bool): Forwarded to the tensor product.
    """

    def __init__(
        self,
        irreps_node_feats: o3.Irreps,
        irreps_edge_attrs: o3.Irreps,
        irreps_node_attrs: o3.Irreps,
        irreps_edge_embed: o3.Irreps,
        irreps_edge_feats: o3.Irreps,
        use_kan: bool = False,
        radial_MLP: Optional[List[int]] = None,
        nonlinearity_type: str = "gate",
        nonlinearity_scalars: Optional[Dict[str, str]] = None,
        nonlinearity_gates: Optional[Dict[str, str]] = None,
        lite_mode: bool = False
    ) -> None:
        super().__init__()
        # ``or`` keeps the historical behaviour: an empty list also selects the default.
        self.radial_MLP = radial_MLP or [64, 64, 64]
        self.use_kan = use_kan
        self.lite_mode = lite_mode

        # Assign irreps
        self.irreps_node_feats = o3.Irreps(irreps_node_feats)
        self.irreps_edge_attrs = o3.Irreps(irreps_edge_attrs)
        self.irreps_edge_embed = o3.Irreps(irreps_edge_embed)
        self.irreps_edge_feats = o3.Irreps(irreps_edge_feats)
        self.irreps_node_attrs = o3.Irreps(irreps_node_attrs)

        assert nonlinearity_type in ("gate", "norm"), "Invalid nonlinearity type."
        # NOTE(review): nonlinearity_scalars / nonlinearity_gates are not consumed
        # anywhere in this block (their defaults used to be {"e": "ssp", "o": "tanh"}
        # and {"e": "ssp", "o": "abs"}); they are kept in the signature only so
        # existing configurations keep working.

        # Linear layers for lifting node features
        self.linear_up_src = self.create_linear(self.irreps_node_feats)
        self.linear_up_dst = self.create_linear(self.irreps_node_feats)

        # TensorProduct layer for edge feature mixing
        self.conv_tp = TensorProductWithMemoryOptimizationWithWeight(
            irreps_input_1=self.irreps_node_feats,
            irreps_input_2=self.irreps_edge_attrs,
            irreps_out=self.irreps_edge_feats,
            irreps_scalar=self.irreps_edge_embed,
            radial_MLP=self.radial_MLP,
            use_kan=self.use_kan,
            lite_mode=self.lite_mode,
        )

    def create_linear(self, irreps_in, irreps_out=None):
        """Create a linear layer with internal, shared weights (``irreps_out`` defaults to ``irreps_in``)."""
        return o3.Linear(
            irreps_in, irreps_out or irreps_in, internal_weights=True, shared_weights=True
        )

    def create_tensor_product(self, irreps_mid, instructions):
        """Create a TensorProduct (node feats x edge attrs -> ``irreps_mid``) with externally supplied weights."""
        return o3.TensorProduct(
            self.irreps_node_feats,
            self.irreps_edge_attrs,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

    def init_weight_generator(self, input_dim, weight_numel):
        """Initialize a weight generator (KAN or FullyConnectedNet) mapping ``input_dim`` to ``weight_numel``."""
        if self.use_kan:
            return KAN([input_dim] + self.radial_MLP + [weight_numel], grid_size=GRID_SIZE, grid_range=GRID_RANGE)
        return FullyConnectedNet(
            [input_dim] + self.radial_MLP + [weight_numel],
            torch.nn.functional.silu,
        )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass of the pair interaction block.

        Parameters:
        - data (Dict[str, torch.Tensor]): A dictionary containing the graph data.
          The result is also written back under ``AtomicDataDict.EDGE_FEATURES_KEY``.

        Returns:
        - torch.Tensor: Updated edge features.
        """
        edge_src, edge_dst = data[AtomicDataDict.EDGE_INDEX_KEY]
        node_feats = data[AtomicDataDict.NODE_FEATURES_KEY]
        edge_embed = data[AtomicDataDict.EDGE_EMBEDDING_KEY]
        edge_attributes = data[AtomicDataDict.EDGE_ATTRS_KEY]

        node_feats_src = self.linear_up_src(node_feats[edge_src])
        node_feats_dst = self.linear_up_dst(node_feats[edge_dst])

        # Mixing node features for edge features
        edge_feats_mix_tp = self.conv_tp(
            node_feats_src + node_feats_dst, edge_attributes, edge_embed
        )
        data[AtomicDataDict.EDGE_FEATURES_KEY] = edge_feats_mix_tp
        return edge_feats_mix_tp
""" Embedding layer which takes scalar nuclear charges Z and transforms them to vectors of size num_features """
class Embedding(nn.Module):
    """Embedding layer which takes scalar nuclear charges Z and transforms them to vectors of size num_features.

    The vector of element ``z`` is the sum of a free learnable row
    ``element_embedding[z]`` and a bias-free linear projection of row ``z`` of
    an electron-configuration table.

    Args:
        num_features (int): Size of the embedding vectors.
        Zmax (int): Number of rows of the learnable element table; the values
            in ``Z`` index into it.
        electron_config (optional): Per-element descriptor table (anything
            ``torch.tensor`` accepts). Its row count must be compatible with
            ``Zmax``, because the two terms are added. Defaults to the
            package's ``electron_configurations``.
    """

    def __init__(self, num_features, Zmax=87, electron_config=None):
        super().__init__()
        self.num_features = num_features
        self.Zmax = Zmax
        if electron_config is None:
            electron_config = electron_configurations
        # torch.tensor always copies, so the module never aliases the source table.
        self.register_buffer('electron_config', torch.tensor(electron_config))
        self.register_parameter('element_embedding', nn.Parameter(torch.empty(self.Zmax, self.num_features)))
        self.config_linear = nn.Linear(self.electron_config.size(1), self.num_features, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the element table (unit-variance uniform) and the projection (orthogonal)."""
        # U(-sqrt(3), sqrt(3)) has variance 1.
        nn.init.uniform_(self.element_embedding, -np.sqrt(3), np.sqrt(3))
        nn.init.orthogonal_(self.config_linear.weight)

    def forward(self, Z):
        """Return the embedding rows selected by the integer tensor ``Z``."""
        # torch.tensor infers the buffer dtype from the table (int64 for integer
        # data, float64 for float64 numpy arrays), which can differ from the
        # Linear weight dtype; cast so the projection works. No-op when equal.
        config = self.electron_config.to(self.config_linear.weight.dtype)
        embedding = self.element_embedding + self.config_linear(config)
        return embedding[Z]