Attention Mechanisms

Attention Layers

Equivariant graph attention blocks and edge softmax aggregation for HamGNN.

Implements attention over message-passing features with e3nn irreps, including AttentionAggregationV2 and related modules.

class hamgnn.nn.attention.AttentionAggregationV2(num_heads, irreps_value)

Bases: Module

An equivariant attention mechanism that aggregates value vectors across the edges of a graph. Unlike AttentionAggregation, it takes edge weights directly instead of separate key and query vectors.

Parameters:
- num_heads (int): Number of attention heads.
- irreps_value (o3.Irreps): Irreducible representations for value vectors.

forward(value, edge_weights, edge_weights_cutoff, edge_index)

Forward pass of the attention mechanism.

Parameters:
- value (torch.Tensor): Value vectors.
- edge_weights (torch.Tensor): Per-edge weights used to form the attention scores.
- edge_weights_cutoff (torch.Tensor): Cutoff weights for edges.
- edge_index (torch.LongTensor): Edge indices.

Returns:
- torch.Tensor: Attended output vectors.

Return type:

Tensor
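The edge attention described above can be illustrated with a plain-Python sketch of softmax aggregation over incoming edges. The helper `edge_softmax_aggregate` is hypothetical and single-head; it assumes that `edge_weights` holds one unnormalized logit per edge, that `value` holds one message vector per edge, that `edge_index[1]` names the receiving node, and that the cutoff factor multiplies the normalized weight. HamGNN's actual tensor implementation may differ on any of these points.

```python
import math
from collections import defaultdict


def edge_softmax_aggregate(value, edge_weights, edge_weights_cutoff, edge_index, num_nodes):
    # Hypothetical single-head sketch, not HamGNN's code.
    # value[e]: message vector carried by edge e (assumed per-edge)
    # edge_weights[e]: unnormalized attention logit of edge e
    # edge_weights_cutoff[e]: smooth cutoff factor of edge e, in [0, 1]
    # edge_index: (src, dst); edge e delivers its message to node dst[e]
    _, dst = edge_index
    # Subtract each node's largest logit before exponentiating, for stability.
    peak = defaultdict(lambda: -math.inf)
    for e, j in enumerate(dst):
        peak[j] = max(peak[j], edge_weights[e])
    expw = [math.exp(edge_weights[e] - peak[j]) for e, j in enumerate(dst)]
    # Softmax denominator: sum over all edges entering the same node.
    denom = defaultdict(float)
    for e, j in enumerate(dst):
        denom[j] += expw[e]
    dim = len(value[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for e, j in enumerate(dst):
        # Assumption: the cutoff rescales the normalized weight afterwards.
        alpha = expw[e] / denom[j] * edge_weights_cutoff[e]
        for k in range(dim):
            out[j][k] += alpha * value[e][k]
    return out


# Edges 0 and 1 both enter node 2 with equal logits; edge 2 enters node 0.
edge_index = ([0, 1, 2], [2, 2, 0])
value = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
out = edge_softmax_aggregate(value, [0.0, 0.0, 5.0], [1.0, 1.0, 1.0], edge_index, num_nodes=3)
```

Subtracting the per-node maximum before exponentiating leaves the softmax unchanged but avoids overflow for large logits.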

class hamgnn.nn.attention.AttentionAggregation(num_heads, irreps_key, irreps_value, irreps_query)

Bases: Module

An equivariant attention mechanism that processes key, value, and query vectors and applies attention across edges in a graph.

Parameters:
- num_heads (int): Number of attention heads.
- irreps_key (o3.Irreps): Irreducible representations for key vectors.
- irreps_value (o3.Irreps): Irreducible representations for value vectors.
- irreps_query (o3.Irreps): Irreducible representations for query vectors.

forward(key, value, query, edge_weight_cutoff, edge_index)

Forward pass of the attention mechanism.

Parameters:
- key (torch.Tensor): Key vectors.
- value (torch.Tensor): Value vectors.
- query (torch.Tensor): Query vectors.
- edge_weight_cutoff (torch.Tensor): Cutoff weights for edges.
- edge_index (torch.LongTensor): Edge indices.

Returns:
- torch.Tensor: Attended output vectors.

Return type:

Tensor
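A minimal sketch of how key and query vectors could be turned into per-head attention weights, assuming scaled dot-product scoring and per-edge key, query and value lists of shape [E][num_heads][d]. Because e3nn uses real orthogonal representations, the dot product between two features with the same irreps is rotation invariant, so scores built this way do not break equivariance. The functions `attention_logits` and `multihead_aggregate` are hypothetical names, not HamGNN's API.

```python
import math


def attention_logits(key, query):
    # Per-edge, per-head scaled dot product: key[e][h] . query[e][h] / sqrt(d).
    return [
        [sum(a * b for a, b in zip(k_h, q_h)) / math.sqrt(len(k_h)) for k_h, q_h in zip(k_e, q_e)]
        for k_e, q_e in zip(key, query)
    ]


def multihead_aggregate(key, value, query, edge_weight_cutoff, edge_index, num_nodes):
    # Hypothetical sketch; edge_index[1] is assumed to hold the receiving nodes.
    _, dst = edge_index
    logits = attention_logits(key, query)
    num_heads, dim_v = len(value[0]), len(value[0][0])
    out = [[[0.0] * dim_v for _ in range(num_heads)] for _ in range(num_nodes)]
    for h in range(num_heads):
        # Max-shifted softmax over the incoming edges of each receiving node.
        peak = {}
        for e, j in enumerate(dst):
            peak[j] = max(peak.get(j, -math.inf), logits[e][h])
        w = [math.exp(logits[e][h] - peak[j]) for e, j in enumerate(dst)]
        norm = {}
        for e, j in enumerate(dst):
            norm[j] = norm.get(j, 0.0) + w[e]
        for e, j in enumerate(dst):
            alpha = w[e] / norm[j] * edge_weight_cutoff[e]
            for c in range(dim_v):
                out[j][h][c] += alpha * value[e][h][c]
    return out


# Two edges enter node 0; head 0 prefers edge 0, head 1 weights both equally.
edge_index = ([1, 2], [0, 0])
key = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
query = [[[1.0, 0.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 0.0]]]
value = [[[1.0], [1.0]], [[3.0], [3.0]]]
out = multihead_aggregate(key, value, query, [1.0, 1.0], edge_index, num_nodes=3)
```

Multiplying by the cutoff after normalization means a node's weights no longer sum to 1 when some cutoff factors are below 1; this is one plausible design choice, not necessarily HamGNN's.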

class hamgnn.nn.attention.AttentionBlockE3(irreps_in, irreps_out, irreps_node_attrs, irreps_edge_feats, irreps_edge_attrs, irreps_edge_embed, num_heads, max_radius, radial_MLP=None, use_skip_connections=True, use_kan=False, nonlinearity_type='gate', nonlinearity_scalars={'e': 'ssp', 'o': 'tanh'}, nonlinearity_gates={'e': 'ssp', 'o': 'abs'})

Bases: Module

An equivariant attention block for processing graph data with attention mechanisms.

Parameters:
- irreps_in (o3.Irreps): Input irreducible representations.
- irreps_out (o3.Irreps): Output irreducible representations.
- irreps_node_attrs (o3.Irreps): Node attribute irreducible representations.
- irreps_edge_feats (o3.Irreps): Edge feature irreducible representations.
- irreps_edge_attrs (o3.Irreps): Edge attribute irreducible representations.
- irreps_edge_embed (o3.Irreps): Edge embedding irreducible representations.
- num_heads (int): Number of attention heads.
- max_radius (float): Maximum radius for the edge cutoff.
- radial_MLP (Optional[List[int]]): Architecture of the radial MLP.
- use_skip_connections (bool): Whether to use skip connections.
- use_kan (bool): Whether to use a KAN for the radial MLP.
- nonlinearity_type (str): Type of nonlinearity (‘gate’ or ‘norm’).
- nonlinearity_scalars (Dict[str, str]): Activation names for scalar features, keyed by parity (‘e’ for even, ‘o’ for odd).
- nonlinearity_gates (Dict[str, str]): Activation names for gate scalars, keyed by parity (‘e’ for even, ‘o’ for odd).
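The max_radius parameter bounds the edge cutoff. As an illustration only, the sketch below uses a cosine envelope that equals 1 at r = 0 and falls smoothly to 0 at max_radius; this is one common choice, and HamGNN may use a different envelope (for example a polynomial one). `cosine_cutoff` is a hypothetical helper.

```python
import math


def cosine_cutoff(r, max_radius):
    # Hypothetical envelope: 1 at r = 0, smoothly 0 at r = max_radius, 0 beyond.
    if r >= max_radius:
        return 0.0
    return 0.5 * (math.cos(math.pi * r / max_radius) + 1.0)


distances = [0.0, 2.5, 5.0, 6.0]
edge_weight_cutoff = [cosine_cutoff(r, max_radius=5.0) for r in distances]
```

A smooth envelope keeps edge contributions continuous as atoms cross max_radius, which matters for smooth energies and Hamiltonians.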

create_linear(irreps_in, irreps_out=None)

Create a linear layer.

create_tensor_product(irreps_mid, instructions)

Create a TensorProduct layer.

init_weight_generator(input_dim, weight_numel)

Initialize weight generator.
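init_weight_generator maps an edge embedding of size input_dim to weight_numel tensor-product weights per edge. The plain-Python sketch below assumes a dense MLP whose hidden sizes come from radial_MLP, with SiLU activations between layers; the helpers `make_mlp` and `generate_weights`, the activation and the initialization are all assumptions. When use_kan is set, a KAN is used for the radial network instead, which this sketch does not cover.

```python
import math
import random


def make_mlp(layer_sizes, seed=0):
    # Hypothetical dense MLP: one (weights, biases) pair per layer.
    rng = random.Random(seed)
    layers = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        scale = 1.0 / math.sqrt(fan_in)
        w = [[rng.gauss(0.0, scale) for _ in range(fan_in)] for _ in range(fan_out)]
        layers.append((w, [0.0] * fan_out))
    return layers


def silu(x):
    return x / (1.0 + math.exp(-x))


def generate_weights(layers, edge_embed):
    # Map one edge embedding to a flat vector of tensor-product weights.
    h = edge_embed
    for i, (w, b) in enumerate(layers):
        h = [sum(wij * hj for wij, hj in zip(row, h)) + bi for row, bi in zip(w, b)]
        if i < len(layers) - 1:  # no activation on the output layer
            h = [silu(v) for v in h]
    return h


input_dim, radial_MLP, weight_numel = 8, [16, 16], 40
layers = make_mlp([input_dim, *radial_MLP, weight_numel])
tp_weights = generate_weights(layers, [0.1] * input_dim)
```

Generating the weights from the edge embedding lets each edge use its own tensor-product weights while the tensor product itself stays equivariant, since the embedding is built from invariant quantities.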

create_nonlinearity(nonlinearity_type, nonlinearity_scalars, nonlinearity_gates)

Create nonlinearity module.
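The nonlinearity dictionaries map the parity of a scalar channel ('e' even, 'o' odd) to an activation name. The sketch below assumes that 'ssp' denotes the shifted softplus, softplus(x) - ln 2, and resolves the default dictionaries to Python callables through a hypothetical `ACTIVATIONS` table. It also checks the parity rule used with e3nn: an activation applied to odd scalars must itself be odd (like tanh, giving odd outputs) or even (like abs, giving even outputs). e3nn's `Activation` module applies a similar numerical check.

```python
import math


def ssp(x):
    # Shifted softplus, softplus(x) - ln 2, written to avoid overflow for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x))) - math.log(2.0)


# Hypothetical name-to-function table; HamGNN's own registry may differ.
ACTIVATIONS = {"ssp": ssp, "tanh": math.tanh, "abs": abs}


def resolve(spec):
    # spec maps the parity of the scalar inputs ('e' even, 'o' odd) to a name.
    return {parity: ACTIVATIONS[name] for parity, name in spec.items()}


def function_parity(fn, xs=(0.3, 1.0, 2.5)):
    # +1 for an even function, -1 for an odd one, 0 if neither.
    if all(math.isclose(fn(-x), fn(x)) for x in xs):
        return 1
    if all(math.isclose(fn(-x), -fn(x)) for x in xs):
        return -1
    return 0


scalars = resolve({"e": "ssp", "o": "tanh"})
gates = resolve({"e": "ssp", "o": "abs"})
# Functions applied to odd scalars must be odd or even to keep parity well defined.
for table in (scalars, gates):
    assert function_parity(table["o"]) != 0
```

This is why the defaults pair ssp, which is neither even nor odd, only with even scalars.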

forward(data)

Forward pass of the attention block.

Parameters:
- data (Dict[str, torch.Tensor]): A dictionary containing the graph data.

Returns:
- Tuple[torch.Tensor, Optional[torch.Tensor]]: Updated node features and skip connection.

Return type:

Tuple[Tensor, Optional[Tensor]]

Attention Utilities

Helpers to split and merge E3 irreps across equivariant attention heads.

Provides VectorToAttentionHeads and AttentionHeadsToVector for multi-head layouts compatible with e3nn irreps.

class hamgnn.nn.attention_utils.VectorToAttentionHeads(irreps_head, num_heads)

Bases: Module

Reshapes vectors of shape [N, irreps_mid] to vectors of shape [N, num_heads, irreps_head].

Attributes:
- num_heads (int): Number of attention heads.
- irreps_head (o3.Irreps): Irreps of each head.
- irreps_mid_in (o3.Irreps): Intermediate irreps.
- mid_in_indices (List[Tuple[int, int]]): Indices for reshaping.

forward(x)

Split x of shape [N, irreps_mid] into num_heads heads and return a tensor of shape [N, num_heads, irreps_head].

Return type:

Tensor
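The reshaping can be pictured with a plain-Python sketch. It assumes a layout in which irreps_mid repeats each irrep block of irreps_head with its multiplicity multiplied by num_heads, and within each block the heads occupy consecutive multiplicity channels. The helpers `parse_irreps`, `mid_in_indices` and `vector_to_heads` are hypothetical and operate on lists rather than tensors.

```python
def parse_irreps(spec):
    # Minimal parser for strings like "4x0e + 2x1o" -> [(mul, l, parity), ...].
    irreps = []
    for term in spec.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        irreps.append((int(mul), int(ir[:-1]), ir[-1]))
    return irreps


def mid_in_indices(irreps_head, num_heads):
    # Start/end of each irrep block of irreps_mid, where every block of
    # irreps_head has its multiplicity scaled by num_heads.
    indices, start = [], 0
    for mul, l, _ in irreps_head:
        end = start + mul * num_heads * (2 * l + 1)
        indices.append((start, end))
        start = end
    return indices


def vector_to_heads(x, irreps_head, num_heads):
    # x: N rows of length irreps_mid.dim -> N x num_heads x irreps_head.dim
    indices = mid_in_indices(irreps_head, num_heads)
    out = []
    for row in x:
        if len(row) != indices[-1][1]:
            raise ValueError("row length must equal irreps_mid.dim")
        heads = [[] for _ in range(num_heads)]
        for start, end in indices:
            chunk = (end - start) // num_heads  # this block's share in one head
            for h in range(num_heads):
                heads[h].extend(row[start + h * chunk:start + (h + 1) * chunk])
        out.append(heads)
    return out


irreps_head = parse_irreps("1x0e + 1x1o")  # one scalar and one vector per head
x = [list(range(8))]                       # irreps_mid = 2x0e + 2x1o, dim 8
heads = vector_to_heads(x, irreps_head, num_heads=2)
```

With irreps_head = 1x0e + 1x1o and two heads, head 0 receives the first scalar and the first vector, and head 1 receives the second of each.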

class hamgnn.nn.attention_utils.AttentionHeadsToVector(irreps_head)

Bases: Module

Converts vectors of shape [N, num_heads, irreps_head] into vectors of shape [N, irreps_head * num_heads].

irreps_head

A list of irreducible representations (irreps) that define the structure of the attention heads.

Type:

o3.Irreps

head_sizes

The sizes of the irrep blocks that make up a single attention head, derived from irreps_head; they sum to the dimension of one head.

Type:

List[int]

forward(x)

Forward pass to process the attention heads and flatten them into a single vector.

Parameters:

x (torch.Tensor) – Input tensor of shape (N, num_heads, input_dim), where:
- N is the batch size.
- num_heads is the number of attention heads.
- input_dim is the size of a single head, i.e. the sum of head_sizes.

Returns:

Output tensor of shape (N, flattened_dim), where flattened_dim is the sum of the dimensions of all attention heads (num_heads × input_dim).

Return type:

torch.Tensor

Raises:

ValueError – If the sum of head_sizes does not match input_dim of the input tensor.
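The inverse operation can be sketched in plain Python, assuming a layout in which the output holds one irrep block at a time with the heads placed consecutively inside each block. `head_sizes` here takes each irrep block of a head to have size multiplicity × (2l + 1), and `heads_to_vector` raises ValueError when a head's length differs from their sum, matching the documented error. Both helpers are hypothetical, and `parse_irreps` is included so the block is self-contained.

```python
def parse_irreps(spec):
    # Minimal parser for strings like "4x0e + 2x1o" -> [(mul, l, parity), ...].
    irreps = []
    for term in spec.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        irreps.append((int(mul), int(ir[:-1]), ir[-1]))
    return irreps


def head_sizes(irreps_head):
    # Size of each irrep block inside one head: multiplicity * (2l + 1).
    return [mul * (2 * l + 1) for mul, l, _ in irreps_head]


def heads_to_vector(x, irreps_head):
    # x: N x num_heads x input_dim -> N rows of length num_heads * input_dim
    sizes = head_sizes(irreps_head)
    out = []
    for heads in x:
        if any(len(h) != sum(sizes) for h in heads):
            raise ValueError("sum of head_sizes does not match input_dim")
        row, start = [], 0
        for size in sizes:
            # Emit one irrep block for every head before moving to the next block.
            for h in heads:
                row.extend(h[start:start + size])
            start += size
        out.append(row)
    return out


irreps_head = parse_irreps("1x0e + 1x1o")
flat = heads_to_vector([[[0, 2, 3, 4], [1, 5, 6, 7]]], irreps_head)
```

The example input is two heads of 1x0e + 1x1o, and the result interleaves the heads block by block: both scalars first, then both vectors.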