Attention Mechanisms
Attention Layers
Equivariant graph attention blocks and edge softmax aggregation for HamGNN.
Implements attention over message-passing features with e3nn irreps, including
AttentionAggregationV2 and related modules.
- class hamgnn.nn.attention.AttentionAggregationV2(num_heads, irreps_value)
Bases: torch.nn.Module
An equivariant attention mechanism that aggregates value vectors across the edges of a graph using precomputed per-edge weights, rather than computing attention from key and query vectors.
Parameters:
- num_heads (int): Number of attention heads.
- irreps_value (o3.Irreps): Irreducible representations for value vectors.
- forward(value, edge_weights, edge_weights_cutoff, edge_index)
Forward pass of the attention mechanism.
Parameters:
- value (torch.Tensor): Value vectors.
- edge_weights (torch.Tensor): Per-edge attention weights.
- edge_weights_cutoff (torch.Tensor): Cutoff weights for edges.
- edge_index (torch.LongTensor): Edge indices.
Returns:
- torch.Tensor: Attended output vectors.
- Return type:
torch.Tensor
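The sketch below illustrates edge-softmax aggregation with precomputed edge weights in plain PyTorch. It is not HamGNN's implementation: the function name `edge_softmax_aggregate`, the tensor shapes, the convention that `edge_index[1]` holds the receiving node, and the choice to multiply the normalized weights by the cutoff are all assumptions.

```python
import torch


def edge_softmax_aggregate(value, edge_weights, edge_weights_cutoff, edge_index, num_nodes):
    """Hypothetical sketch of edge-softmax attention aggregation.

    Assumed shapes (not confirmed by the HamGNN docs):
      value:               [E, H, D]  per-edge value features for H heads
      edge_weights:        [E, H]     unnormalized per-edge attention logits
      edge_weights_cutoff: [E]        smooth cutoff factor per edge
      edge_index:          [2, E]     (source, destination) node indices
    Returns node features of shape [num_nodes, H, D].
    """
    dst = edge_index[1]  # assumption: row 1 holds the receiving node
    num_heads = edge_weights.shape[1]

    # Per-node, per-head maximum logit for a numerically stable softmax.
    idx = dst.unsqueeze(1).expand(-1, num_heads)
    node_max = torch.full((num_nodes, num_heads), float("-inf"), dtype=edge_weights.dtype)
    node_max = node_max.scatter_reduce(0, idx, edge_weights, reduce="amax", include_self=True)
    exp = torch.exp(edge_weights - node_max[dst])

    # Normalize over the incoming edges of each destination node.
    denom = torch.zeros(num_nodes, num_heads, dtype=edge_weights.dtype).index_add_(0, dst, exp)
    alpha = exp / denom[dst]

    # Assumption: the cutoff damps edges near max_radius after normalization.
    alpha = alpha * edge_weights_cutoff.unsqueeze(1)

    # Scalar weights times value vectors, summed into the receiving nodes.
    messages = alpha.unsqueeze(-1) * value
    out = torch.zeros(num_nodes, num_heads, value.shape[-1], dtype=value.dtype)
    return out.index_add_(0, dst, messages)
```

Because each attention weight is a single scalar per edge and head, scaling a value vector by it does not mix irrep components, which is why this style of aggregation can preserve equivariance.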
- class hamgnn.nn.attention.AttentionAggregation(num_heads, irreps_key, irreps_value, irreps_query)
Bases: torch.nn.Module
An equivariant attention mechanism that processes key, value, and query vectors and applies attention across edges in a graph.
Parameters:
- num_heads (int): Number of attention heads.
- irreps_key (o3.Irreps): Irreducible representations for key vectors.
- irreps_value (o3.Irreps): Irreducible representations for value vectors.
- irreps_query (o3.Irreps): Irreducible representations for query vectors.
- forward(key, value, query, edge_weight_cutoff, edge_index)
Forward pass of the attention mechanism.
Parameters:
- key (torch.Tensor): Key vectors.
- value (torch.Tensor): Value vectors.
- query (torch.Tensor): Query vectors.
- edge_weight_cutoff (torch.Tensor): Cutoff weights for edges.
- edge_index (torch.LongTensor): Edge indices.
Returns:
- torch.Tensor: Attended output vectors.
- Return type:
torch.Tensor
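When attention logits are computed from keys and queries, one way to keep them rotation invariant is to take a dot product between features that transform under the same orthogonal representation. The hypothetical sketch below uses only l = 1 (vector) features and checks numerically that rotating keys and queries together leaves the logits unchanged. The names `dot_product_logits` and `random_rotation` are illustrative, and HamGNN's `AttentionAggregation` may compute its logits differently.

```python
import math

import torch


def dot_product_logits(query, key, edge_index):
    """Scaled dot-product logits per edge and head (illustrative sketch).

    Assumed shapes: query and key are [N, H, 3] node features (l = 1 vectors);
    edge_index is [2, E] with row 0 = source (key) and row 1 = destination (query).
    """
    src, dst = edge_index
    return (query[dst] * key[src]).sum(-1) / math.sqrt(query.shape[-1])


def random_rotation(generator=None):
    """Random 3x3 proper rotation matrix built from a QR decomposition."""
    q, r = torch.linalg.qr(torch.randn(3, 3, generator=generator, dtype=torch.float64))
    q = q * torch.sign(torch.diagonal(r))  # fix column signs of the orthogonal factor
    if torch.det(q) < 0:                   # flip one axis so that det(q) = +1
        q[:, 0] = -q[:, 0]
    return q
```

Rotating every l = 1 feature by the same matrix R changes each vector but not the dot product between any two of them, so attention weights derived from these logits are invariant, and multiplying them onto equivariant value vectors keeps the output equivariant.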
- class hamgnn.nn.attention.AttentionBlockE3(irreps_in, irreps_out, irreps_node_attrs, irreps_edge_feats, irreps_edge_attrs, irreps_edge_embed, num_heads, max_radius, radial_MLP=None, use_skip_connections=True, use_kan=False, nonlinearity_type='gate', nonlinearity_scalars={'e': 'ssp', 'o': 'tanh'}, nonlinearity_gates={'e': 'ssp', 'o': 'abs'})
Bases: torch.nn.Module
An equivariant attention block for processing graph data with attention mechanisms.
Parameters:
- irreps_in (o3.Irreps): Input irreducible representations.
- irreps_out (o3.Irreps): Output irreducible representations.
- irreps_node_attrs (o3.Irreps): Node attribute irreducible representations.
- irreps_edge_feats (o3.Irreps): Edge feature irreducible representations.
- irreps_edge_attrs (o3.Irreps): Edge attribute irreducible representations.
- irreps_edge_embed (o3.Irreps): Edge embedding irreducible representations.
- num_heads (int): Number of attention heads.
- max_radius (float): Maximum radius for the edge cutoff.
- radial_MLP (Optional[List[int]]): Architecture of the radial MLP.
- use_skip_connections (bool): Whether to use skip connections.
- use_kan (bool): Whether to use a KAN for the radial MLP.
- nonlinearity_type (str): Type of nonlinearity ('gate' or 'norm').
- nonlinearity_scalars (Dict[str, str]): Scalar nonlinearity names, keyed by parity ('e' for even, 'o' for odd).
- nonlinearity_gates (Dict[str, str]): Gate nonlinearity names, keyed by parity ('e' for even, 'o' for odd).
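The default `nonlinearity_scalars` and `nonlinearity_gates` pair even-parity ('e') inputs with 'ssp' and odd-parity ('o') inputs with 'tanh' (an odd function) or 'abs' (an even function). A likely reason is that an activation applied to odd scalars must itself be odd or even for the output to have a definite parity, a constraint that e3nn's activation modules check. The sketch below assumes 'ssp' means shifted softplus, softplus(x) - log 2. The `ACTIVATIONS` table and the helpers `parity_of` and `check_nonlinearity` are hypothetical, not HamGNN API.

```python
import math

import torch
import torch.nn.functional as F


def ssp(x):
    """Shifted softplus, softplus(x) - log(2), so ssp(0) == 0 (assumed meaning of 'ssp')."""
    return F.softplus(x) - math.log(2.0)


# Hypothetical name -> function table covering the default settings.
ACTIVATIONS = {"ssp": ssp, "tanh": torch.tanh, "abs": torch.abs}


def parity_of(fn, n=101):
    """Return +1 if fn is even, -1 if odd, 0 if neither (checked numerically on [-3, 3])."""
    x = torch.linspace(-3.0, 3.0, n, dtype=torch.float64)
    y_pos, y_neg = fn(x), fn(-x)
    if torch.allclose(y_neg, y_pos):
        return 1
    if torch.allclose(y_neg, -y_pos):
        return -1
    return 0


def check_nonlinearity(mapping):
    """Reject activations for odd ('o') scalars that are neither even nor odd."""
    for parity, name in mapping.items():
        if parity == "o" and parity_of(ACTIVATIONS[name]) == 0:
            raise ValueError(f"{name!r} is neither even nor odd; cannot act on odd scalars")
```

Both default mappings pass this check, while using 'ssp' for odd scalars would not, since ssp(-x) = ssp(x) - x is neither even nor odd.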
Attention Utilities
Helpers to split and merge E3 irreps across equivariant attention heads.
Provides VectorToAttentionHeads and AttentionHeadsToVector for
multi-head layouts compatible with e3nn irreps.
- class hamgnn.nn.attention_utils.VectorToAttentionHeads(irreps_head, num_heads)
Bases: torch.nn.Module
Reshapes vectors of shape [N, irreps_mid] into vectors of shape [N, num_heads, irreps_head].
Attributes:
- num_heads (int): Number of attention heads.
- irreps_head (o3.Irreps): Irreps of each head.
- irreps_mid_in (o3.Irreps): Intermediate irreps.
- mid_in_indices (List[Tuple[int, int]]): Indices for reshaping.
- forward(x)
Reshapes x from shape [N, irreps_mid] to shape [N, num_heads, irreps_head].
- Return type:
torch.Tensor
Note
Although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly: the former runs the registered hooks, while the latter silently ignores them.
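The reshape can be sketched in plain PyTorch, assuming the per-irrep block layout that the `irreps_mid_in` and `mid_in_indices` attributes suggest: each multiplicity in `irreps_head` is multiplied by `num_heads`, and every irrep type occupies one contiguous block in which the heads appear one after another. Irreps are written here as hypothetical `(mul, l)` pairs instead of `o3.Irreps`, and `vector_to_heads` is an illustrative name.

```python
import torch


def vector_to_heads(x, irreps_head, num_heads):
    """Split [N, irreps_mid] into [N, num_heads, irreps_head] (illustrative sketch).

    irreps_head: hypothetical list of (mul, l) pairs, e.g. [(2, 0), (1, 1)].
    """
    sizes = [mul * (2 * l + 1) for mul, l in irreps_head]  # per-head size of each block
    n, width = x.shape
    if width != num_heads * sum(sizes):
        raise ValueError(f"expected {num_heads * sum(sizes)} features, got {width}")
    out, start = [], 0
    for size in sizes:
        length = num_heads * size                       # this irrep type for all heads
        block = x.narrow(1, start, length)              # [N, num_heads * size]
        out.append(block.reshape(n, num_heads, size))   # consecutive slices -> heads
        start += length
    return torch.cat(out, dim=2)                        # [N, num_heads, sum(sizes)]
```

For example, with irreps_head = 2x0e + 1x1e and three heads, the first six input columns would hold the scalars of all three heads and the next nine would hold their vectors.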
- class hamgnn.nn.attention_utils.AttentionHeadsToVector(irreps_head)
Bases: torch.nn.Module
Converts vectors of shape [N, num_heads, irreps_head] into vectors of shape [N, irreps_head * num_heads].
- irreps_head
A list of irreducible representations (irreps) that define the structure of the attention heads.
- Type:
o3.Irreps
- forward(x)
Forward pass that flattens the attention heads into a single vector.
- Parameters:
x (torch.Tensor) – Input tensor of shape (N, num_heads, input_dim), where:
  - N is the leading (batch) dimension.
  - num_heads is the number of attention heads.
  - input_dim is the dimension of a single head, i.e. irreps_head.dim.
- Returns:
Output tensor of shape (N, flattened_dim), where flattened_dim = num_heads * input_dim is the sum of the dimensions of all attention heads.
- Return type:
torch.Tensor
- Raises:
ValueError – If the sum of head_sizes does not match input_dim of the input tensor.
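A sketch of the flattening in plain PyTorch, assuming a per-irrep block layout: the slice for each irrep type is taken from every head and written out as one contiguous block, and a mismatch between the summed head sizes and `input_dim` raises `ValueError`, mirroring the documented behavior. The `(mul, l)` list and the name `heads_to_vector` are hypothetical.

```python
import torch


def heads_to_vector(x, irreps_head):
    """Flatten [N, num_heads, input_dim] into [N, num_heads * input_dim] (sketch).

    irreps_head: hypothetical list of (mul, l) pairs describing one head.
    """
    sizes = [mul * (2 * l + 1) for mul, l in irreps_head]  # head_sizes per irrep type
    n, num_heads, input_dim = x.shape
    if sum(sizes) != input_dim:
        raise ValueError(f"sum of head sizes {sum(sizes)} != input_dim {input_dim}")
    out, start = [], 0
    for size in sizes:
        block = x.narrow(2, start, size)                # this irrep type, all heads
        out.append(block.reshape(n, num_heads * size))  # heads stored one after another
        start += size
    return torch.cat(out, dim=1)
```

Grouping by irrep type rather than by head keeps all copies of each irrep contiguous, which matches how e3nn lays out features with a multiplicity greater than one.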