# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only
"""Equivariant graph attention blocks and edge softmax aggregation for HamGNN.
Implements attention over message-passing features with ``e3nn`` irreps, including
:class:`AttentionAggregationV2` and related modules.
"""
import math
from typing import Callable, Dict, List, Optional, Tuple
import torch
from e3nn import o3
from e3nn.nn import FullyConnectedNet, Gate, NormActivation
from e3nn.util.jit import compile_mode
from torch import nn
from torch_geometric.utils import softmax as edge_softmax
from torch_scatter import scatter
from .attention_utils import AttentionHeadsToVector, VectorToAttentionHeads
from .interaction_blocks import ResidualBlock
from .message_passing import MessagePackBlock
from ..toolbox.efficient_kan import KAN
from ..toolbox.nequip.data import AtomicDataDict
from ..toolbox.nequip.nn import GraphModuleMixin
from ..utils.cutoff_functions import SoftUnitStepCutoff
from ..utils.irreps_utils import acts, irreps2gate, scale_irreps
from ..utils.macro import GRID_RANGE, GRID_SIZE
@compile_mode("script")
class AttentionAggregationV2(nn.Module):
    """
    Equivariant multi-head attention aggregation driven by precomputed logits.

    Unlike :class:`AttentionAggregation`, this variant does not build the
    attention logits from key/query dot products; the caller supplies one
    logit per edge and head. The value features are split into ``num_heads``
    heads, weighted by an edge softmax over the incoming edges of each
    destination node, summed onto that node, and merged back into a single
    feature vector.

    Parameters:
    - num_heads (int): Number of attention heads.
    - irreps_value (o3.Irreps): Irreducible representations for value vectors;
      each head works on ``irreps_value`` scaled by ``1 / num_heads``
      (via ``scale_irreps``).
    """

    def __init__(
        self,
        num_heads: int,
        irreps_value: o3.Irreps,
    ):
        super().__init__()
        self.num_heads = num_heads
        irreps_value = o3.Irreps(irreps_value)
        self.value_irreps_head = scale_irreps(irreps_value, 1 / num_heads)
        # Split a fused value vector into per-head chunks, and merge them back.
        self.unfuse_value = VectorToAttentionHeads(self.value_irreps_head, num_heads)
        self.fuse_value = AttentionHeadsToVector(self.value_irreps_head)

    def forward(
        self,
        value: torch.Tensor,  # (num_edges, value_dim)
        edge_weights: torch.Tensor,  # (num_edges, num_heads) attention logits
        edge_weights_cutoff: Optional[torch.Tensor],  # (num_edges,) or None
        edge_index: torch.Tensor,  # (2, num_edges): rows are (src, dst)
        num_nodes: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Aggregate per-edge values onto destination nodes with attention.

        Parameters:
        - value (torch.Tensor): Per-edge value vectors laid out as ``irreps_value``.
        - edge_weights (torch.Tensor): Unnormalised attention logits, one per
          edge and head.
        - edge_weights_cutoff (Optional[torch.Tensor]): Per-edge factor that
          multiplies the logits before the softmax; skipped when None.
        - edge_index (torch.Tensor): Edge indices; row 1 holds the destination
          node of each edge.
        - num_nodes (Optional[int]): Number of nodes in the graph. When given,
          it fixes the number of output rows, so nodes without incoming edges
          still receive a (zero) row. When None, the size is inferred from the
          largest destination index, which is the original behaviour.

        Returns:
        - torch.Tensor: Attended node features, one row per destination node.
        """
        value = self.unfuse_value(value)
        edge_dst = edge_index[1]
        # NOTE(review): the cutoff scales the *logits*, not the softmax weights.
        # If the cutoff reaches 0 at the radius, the logit becomes 0 (not -inf),
        # so the edge still contributes after normalisation. Confirm this is
        # intended.
        if edge_weights_cutoff is not None:
            edge_weights = edge_weights_cutoff[:, None] * edge_weights  # (num_edges, num_heads)
        # Normalise over all edges that share the same destination node.
        edge_weights = edge_softmax(edge_weights, edge_dst, num_nodes=num_nodes)  # (num_edges, num_heads)
        edge_weights = edge_weights.unsqueeze(-1)  # (num_edges, num_heads, 1)
        # Sum the weighted values onto their destination node.
        f_out = scatter(edge_weights * value, edge_dst, dim=0, dim_size=num_nodes)
        return self.fuse_value(f_out)  # merge heads
@compile_mode("script")
class AttentionAggregation(nn.Module):
    """
    Equivariant multi-head scaled dot-product attention over graph edges.

    Key, value and query vectors (one of each per edge) are split into
    ``num_heads`` heads. Per head, the attention logit is the dot product of
    query and key, optionally multiplied by a per-edge cutoff and divided by
    ``sqrt(head_dim)``. The logits are soft-maxed over the incoming edges of
    each destination node, and the weighted values are summed onto that node.

    Parameters:
    - num_heads (int): Number of attention heads.
    - irreps_key (o3.Irreps): Irreducible representations for key vectors.
    - irreps_value (o3.Irreps): Irreducible representations for value vectors.
    - irreps_query (o3.Irreps): Irreducible representations for query vectors.

    NOTE(review): the per-head ``(query * key).sum(-1)`` assumes the key and
    query heads share the same layout (``irreps_key == irreps_query``).
    """

    def __init__(
        self,
        num_heads: int,
        irreps_key: o3.Irreps,
        irreps_value: o3.Irreps,
        irreps_query: o3.Irreps
    ):
        super().__init__()
        self.num_heads = num_heads
        self.irreps_key = o3.Irreps(irreps_key)
        irreps_value = o3.Irreps(irreps_value)
        irreps_query = o3.Irreps(irreps_query)
        # Use the normalised Irreps for all three, so string inputs behave the same.
        self.key_irreps_head = scale_irreps(self.irreps_key, 1 / num_heads)
        self.value_irreps_head = scale_irreps(irreps_value, 1 / num_heads)
        self.query_irreps_head = scale_irreps(irreps_query, 1 / num_heads)
        # Scaled dot-product denominator; constant for the module's lifetime.
        self.key_head_scale = math.sqrt(self.key_irreps_head.dim)
        self.unfuse_key = VectorToAttentionHeads(self.key_irreps_head, num_heads)
        self.unfuse_value = VectorToAttentionHeads(self.value_irreps_head, num_heads)
        self.unfuse_query = VectorToAttentionHeads(self.query_irreps_head, num_heads)
        self.fuse_value = AttentionHeadsToVector(self.value_irreps_head)

    def forward(
        self,
        key: torch.Tensor,  # (num_edges, hidden_feat_len)
        value: torch.Tensor,  # (num_edges, hidden_feat_len)
        query: torch.Tensor,  # (num_edges, hidden_feat_len)
        edge_weight_cutoff: Optional[torch.Tensor],  # (num_edges,) or None
        edge_index: torch.Tensor,  # (2, num_edges): rows are (src, dst)
        num_nodes: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the attention mechanism.

        Parameters:
        - key (torch.Tensor): Per-edge key vectors.
        - value (torch.Tensor): Per-edge value vectors.
        - query (torch.Tensor): Per-edge query vectors.
        - edge_weight_cutoff (Optional[torch.Tensor]): Per-edge factor that
          multiplies the logits before scaling and softmax; skipped when None.
        - edge_index (torch.Tensor): Edge indices; row 1 holds the destination
          node of each edge.
        - num_nodes (Optional[int]): Number of nodes in the graph. When given,
          it fixes the number of output rows, so nodes without incoming edges
          still receive a (zero) row. When None, the size is inferred from the
          largest destination index, which is the original behaviour.

        Returns:
        - torch.Tensor: Attended output vectors, one row per destination node.
        """
        key = self.unfuse_key(key)
        value = self.unfuse_value(value)
        query = self.unfuse_query(query)
        edge_dst = edge_index[1]
        # Per-head dot product -> one logit per edge and head.
        edge_weights = (query * key).sum(-1)  # (num_edges, num_heads)
        if edge_weight_cutoff is not None:
            edge_weights = edge_weight_cutoff[:, None] * edge_weights  # (num_edges, num_heads)
        edge_weights = edge_weights / self.key_head_scale
        # Normalise over all edges that share the same destination node.
        edge_weights = edge_softmax(edge_weights, edge_dst, num_nodes=num_nodes)  # (num_edges, num_heads)
        edge_weights = edge_weights.unsqueeze(-1)  # (num_edges, num_heads, 1)
        # Sum the weighted values onto their destination node.
        f_out = scatter(edge_weights * value, edge_dst, dim=0, dim_size=num_nodes)  # (num_nodes, num_heads, irreps_head)
        return self.fuse_value(f_out)  # merge heads
@compile_mode("script")
class AttentionBlockE3(nn.Module):
    """
    An equivariant attention block for processing graph data with attention mechanisms.

    For every edge (sender -> receiver):
    - the key is a linear projection of the sender's node features;
    - the query is a linear projection of the receiver's node features;
    - the value comes from a :class:`MessagePackBlock` applied to the sender,
      receiver and edge features.
    :class:`AttentionAggregation` then aggregates the values onto the receiver
    nodes. The result passes through a :class:`ResidualBlock` and, optionally,
    is added to a linear skip connection of the input node features.

    Parameters:
    - irreps_in (o3.Irreps): Input irreducible representations.
    - irreps_out (o3.Irreps): Output irreducible representations.
    - irreps_node_attrs (o3.Irreps): Node attribute irreducible representations.
    - irreps_edge_feats (o3.Irreps): Edge feature irreducible representations.
    - irreps_edge_attrs (o3.Irreps): Edge attribute irreducible representations.
    - irreps_edge_embed (o3.Irreps): Edge embedding irreducible representations.
    - num_heads (int): Number of attention heads.
    - max_radius (float): Maximum radius for edge cutoff.
    - radial_MLP (Optional[List[int]]): Hidden layer sizes of the radial MLP.
      Falls back to [64, 64, 64] when None or empty.
    - use_skip_connections (bool): Whether to use skip connections.
    - use_kan (bool): Whether to use KAN for the radial MLP.
    - nonlinearity_type (str): Type of nonlinearity ('gate' or 'norm').
    - nonlinearity_scalars (Optional[Dict[str, str]]): Scalar activation names
      keyed by parity ("e"/"o"). Defaults to {"e": "ssp", "o": "tanh"}.
    - nonlinearity_gates (Optional[Dict[str, str]]): Gate activation names
      keyed by parity ("e"/"o"). Defaults to {"e": "ssp", "o": "abs"}.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        irreps_node_attrs: o3.Irreps,
        irreps_edge_feats: o3.Irreps,
        irreps_edge_attrs: o3.Irreps,
        irreps_edge_embed: o3.Irreps,
        num_heads: int,
        max_radius: float,
        radial_MLP: Optional[List[int]] = None,
        use_skip_connections: bool = True,
        use_kan: bool = False,
        nonlinearity_type: str = "gate",
        nonlinearity_scalars: Optional[Dict[str, str]] = None,
        nonlinearity_gates: Optional[Dict[str, str]] = None,
    ):
        super().__init__()
        self.radial_MLP = radial_MLP or [64, 64, 64]
        self.use_kan = use_kan
        self.use_skip_connections = use_skip_connections
        assert nonlinearity_type in ("gate", "norm"), "Invalid nonlinearity type."
        # Build the defaults per call instead of sharing mutable default dicts.
        if nonlinearity_scalars is None:
            nonlinearity_scalars = {"e": "ssp", "o": "tanh"}
        if nonlinearity_gates is None:
            nonlinearity_gates = {"e": "ssp", "o": "abs"}
        # Re-key from parity labels ("e"/"o") to parity signs (1/-1), the form
        # create_nonlinearity() indexes. __init__ does not use these maps any
        # further, but the lookups still reject mappings missing "e" or "o".
        nonlinearity_scalars = {
            1: nonlinearity_scalars["e"],
            -1: nonlinearity_scalars["o"],
        }
        nonlinearity_gates = {
            1: nonlinearity_gates["e"],
            -1: nonlinearity_gates["o"],
        }
        # Assign irreps
        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_edge_attrs = o3.Irreps(irreps_edge_attrs)
        self.irreps_edge_embed = o3.Irreps(irreps_edge_embed)
        self.irreps_edge_feats = o3.Irreps(irreps_edge_feats)
        self.irreps_node_attrs = o3.Irreps(irreps_node_attrs)
        self.register_buffer(
            "max_radius", torch.tensor(max_radius, dtype=torch.get_default_dtype())
        )
        self.cutoff_func = SoftUnitStepCutoff(cutoff=max_radius)
        # Linear transformations feeding the value message.
        self.linear_up_src = self.create_linear(self.irreps_in)
        self.linear_up_tar = self.create_linear(self.irreps_in)
        # NOTE(review): this layer is applied to edge features but is built
        # for irreps_in -> irreps_in. That only lines up when
        # irreps_edge_feats == irreps_in; confirm callers guarantee it.
        self.linear_up_edge = self.create_linear(self.irreps_in)
        # Post-attention block mapping irreps_in -> irreps_out.
        self.residual = ResidualBlock(self.irreps_in, self.irreps_out)
        # Value message built from sender/receiver/edge features.
        # NOTE(review): this produces irreps_out, but the attention below
        # treats values as irreps_in. The two only agree when
        # irreps_out == irreps_in.
        self.conv_tp_value = MessagePackBlock(irreps_node_feats=self.irreps_in,
                                              irreps_edge_feats=self.irreps_edge_feats,
                                              irreps_local_env_edge=self.irreps_edge_attrs,
                                              irreps_out=self.irreps_out,
                                              irreps_edge_scalars=self.irreps_edge_embed,
                                              radial_MLP=self.radial_MLP,
                                              use_kan=self.use_kan)
        # Linear layers for key and query
        self.linear_key = self.create_linear(self.irreps_in, self.irreps_in)
        self.linear_query = self.create_linear(self.irreps_in, self.irreps_in)
        # Attention mechanism
        self.attention = AttentionAggregation(
            num_heads=num_heads,
            irreps_key=self.irreps_in,
            irreps_value=self.irreps_in,
            irreps_query=self.irreps_in,
        )
        # Skip connection
        if self.use_skip_connections:
            self.skip_linear = self.create_linear(self.irreps_in, self.irreps_out)

    def create_linear(self, irreps_in, irreps_out=None):
        """Create an equivariant linear layer ``irreps_in -> irreps_out``.

        ``irreps_out`` defaults to ``irreps_in``. The weights are internal and
        shared across all inputs.
        """
        return o3.Linear(
            irreps_in, irreps_out or irreps_in, internal_weights=True, shared_weights=True
        )

    def create_tensor_product(self, irreps_mid, instructions):
        """Create a ``irreps_in x irreps_edge_attrs -> irreps_mid`` tensor product.

        The weights are external and unshared (``shared_weights=False``,
        ``internal_weights=False``), so the caller must supply them per call.
        """
        return o3.TensorProduct(
            self.irreps_in,
            self.irreps_edge_attrs,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

    def init_weight_generator(self, input_dim, weight_numel):
        """Build the radial network mapping ``input_dim -> weight_numel``.

        Its hidden layers are ``self.radial_MLP``. A KAN is used when
        ``use_kan`` is set; otherwise a SiLU FullyConnectedNet is used.
        """
        layer_sizes = [input_dim] + self.radial_MLP + [weight_numel]
        if self.use_kan:
            return KAN(layer_sizes, grid_size=GRID_SIZE, grid_range=GRID_RANGE)
        return FullyConnectedNet(
            layer_sizes,
            torch.nn.functional.silu,
        )

    def create_nonlinearity(self, nonlinearity_type, nonlinearity_scalars, nonlinearity_gates):
        """Create a Gate (``"gate"``) or NormActivation (any other type) over ``irreps_in``.

        The ``nonlinearity_*`` maps are expected to be keyed by parity sign
        (1 / -1); the norm branch uses only ``nonlinearity_scalars[1]``.
        """
        if nonlinearity_type == "gate":
            irreps_scalars, irreps_gates, irreps_gated, act_scalars, act_gates = irreps2gate(
                self.irreps_in, nonlinearity_scalars, nonlinearity_gates
            )
            return Gate(
                irreps_scalars=irreps_scalars,
                act_scalars=act_scalars,
                irreps_gates=irreps_gates,
                act_gates=act_gates,
                irreps_gated=irreps_gated,
            )
        return NormActivation(
            irreps_in=self.irreps_in,
            scalar_nonlinearity=acts[nonlinearity_scalars[1]],
            normalize=True,
            epsilon=1e-8,
            bias=False,
        )

    def forward(
        self,
        data: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass of the attention block.

        Parameters:
        - data (Dict[str, torch.Tensor]): Graph data. Reads the edge index,
          node features, edge embedding, edge attributes, edge features and
          edge length entries.

        Returns:
        - torch.Tensor: Updated node features. The same tensor is also written
          back to ``data[AtomicDataDict.NODE_FEATURES_KEY]``.
        """
        edge_index = data[AtomicDataDict.EDGE_INDEX_KEY]
        sender = edge_index[0]
        receiver = edge_index[1]
        node_feats = data[AtomicDataDict.NODE_FEATURES_KEY]
        edge_embed = data[AtomicDataDict.EDGE_EMBEDDING_KEY]
        edge_attrs = data[AtomicDataDict.EDGE_ATTRS_KEY]
        edge_feats = data[AtomicDataDict.EDGE_FEATURES_KEY]
        # Skip connection, computed from the input features before they are replaced.
        sc: Optional[torch.Tensor] = None
        if self.use_skip_connections:
            sc = self.skip_linear(node_feats)
        # Keys come from senders and queries from receivers, each through its own
        # projection. NOTE: earlier versions reused linear_key for the query, so
        # linear_query in checkpoints trained before this fix was never trained.
        # To reproduce those models, copy linear_key's weights into linear_query.
        key = self.linear_key(node_feats)[sender]
        query = self.linear_query(node_feats)[receiver]
        value = self.conv_tp_value(self.linear_up_src(node_feats)[sender],
                                   self.linear_up_tar(node_feats)[receiver],
                                   self.linear_up_edge(edge_feats),
                                   edge_attrs,
                                   edge_embed)
        # Attention mechanism, with a smooth distance cutoff on the logits.
        edge_weight_cutoff = self.cutoff_func(data[AtomicDataDict.EDGE_LENGTH_KEY])
        node_feats = self.attention(key, value, query, edge_weight_cutoff, edge_index=edge_index)
        # Post-attention residual block (irreps_in -> irreps_out).
        node_feats = self.residual(node_feats)
        # Add out-of-place: an in-place add on an autograd output can corrupt
        # tensors saved for backward.
        if sc is not None:
            node_feats = node_feats + sc
        data[AtomicDataDict.NODE_FEATURES_KEY] = node_feats
        return node_feats