Utilities
General utility modules for HamGNN v2.1.
Irreps Utilities
- hamgnn.utils.irreps_utils.extract_scalar_irreps(irreps)[source]
Extracts and returns the scalar irreducible representations (irreps) from the given irreps.
A scalar irrep is defined as one with l=0 and p=1. This function calculates the total multiplicity of such scalar irreps and constructs a new Irreps object containing only these.
Parameters: - irreps (o3.Irreps): The input irreps from which to extract scalar components.
Returns: - o3.Irreps: An Irreps object containing only the scalar components.
- Return type: o3.Irreps
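The extraction described above can be sketched without e3nn installed. This is only an illustration, not HamGNN's implementation: irreps are modelled as plain `(multiplicity, l, parity)` tuples standing in for `o3.Irreps`, and the name `extract_scalar_irreps_sketch` and the empty-list result for inputs without scalars are assumptions.

```python
from typing import List, Tuple

# Hypothetical stand-in for o3.Irreps: a list of (multiplicity, l, parity) tuples.
Irrep = Tuple[int, int, int]


def extract_scalar_irreps_sketch(irreps: List[Irrep]) -> List[Irrep]:
    """Sum the multiplicities of all l=0, p=1 entries into one scalar irrep."""
    total = sum(mul for mul, l, p in irreps if l == 0 and p == 1)
    # Assumption: with no scalar component, return an empty collection.
    return [(total, 0, 1)] if total > 0 else []


# In e3nn notation this input would read "4x0e + 2x1o + 3x0e + 1x0o".
irreps = [(4, 0, 1), (2, 1, -1), (3, 0, 1), (1, 0, -1)]
scalars = extract_scalar_irreps_sketch(irreps)
print(scalars)  # [(7, 0, 1)], i.e. "7x0e"
```

The odd-parity `1x0o` entry is skipped because the definition above requires p=1.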
- hamgnn.utils.irreps_utils.irreps2gate(irreps, nonlinearity_scalars={-1: 'tanh', 1: 'ssp'}, nonlinearity_gates={-1: 'abs', 1: 'ssp'})[source]
Splits irreducible representations into scalar and gated components and associates activation functions.
Parameters:
- irreps (o3.Irreps): The input irreducible representations.
- nonlinearity_scalars (Dict[int, str]): Activation functions for scalar components.
- nonlinearity_gates (Dict[int, str]): Activation functions for gate components.
Returns: - Tuple containing:
irreps_scalars (o3.Irreps): Scalar irreps.
irreps_gates (o3.Irreps): Gate irreps.
irreps_gated (o3.Irreps): Gated irreps.
act_scalars (List[Callable]): Activation functions for scalars.
act_gates (List[Callable]): Activation functions for gates.
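One plausible way to produce the five return values is sketched below. It rests on several assumptions not stated in the reference: the dictionary keys -1 and 1 are read as parities, every l>0 entry is treated as gated, and each gated entry receives one even scalar gate per channel. Irreps are again `(multiplicity, l, parity)` tuples, and the sketch returns activation names rather than callables.

```python
from typing import Dict, List, Optional, Tuple

Irrep = Tuple[int, int, int]  # hypothetical stand-in: (multiplicity, l, parity)


def irreps2gate_sketch(
    irreps: List[Irrep],
    nonlinearity_scalars: Optional[Dict[int, str]] = None,
    nonlinearity_gates: Optional[Dict[int, str]] = None,
):
    """Split irreps into scalars, gates and gated parts (illustrative only)."""
    # Defaults copied from the documented signature.
    if nonlinearity_scalars is None:
        nonlinearity_scalars = {-1: "tanh", 1: "ssp"}
    if nonlinearity_gates is None:
        nonlinearity_gates = {-1: "abs", 1: "ssp"}

    scalars = [(mul, l, p) for mul, l, p in irreps if l == 0]
    gated = [(mul, l, p) for mul, l, p in irreps if l > 0]
    # Assumption: one even scalar gate per gated channel.
    gates = [(mul, 0, 1) for mul, _, _ in gated]

    # Assumption: activations are looked up by the parity of each entry.
    act_scalars = [nonlinearity_scalars[p] for _, _, p in scalars]
    act_gates = [nonlinearity_gates[p] for _, _, p in gates]
    return scalars, gates, gated, act_scalars, act_gates


# "8x0e + 4x0o + 4x1o + 2x2e" in e3nn notation.
parts = irreps2gate_sketch([(8, 0, 1), (4, 0, -1), (4, 1, -1), (2, 2, 1)])
irreps_scalars, irreps_gates, irreps_gated, act_scalars, act_gates = parts
print(irreps_gates)  # [(4, 0, 1), (2, 0, 1)]
print(act_scalars)   # ['ssp', 'tanh']
```

The return order follows the tuple documented above: scalars, gates, gated, then the two activation lists.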
- hamgnn.utils.irreps_utils.scale_irreps(irreps, factor)[source]
Scales the multiplicities of the irreducible representations (irreps) by a given factor, ensuring they remain at least 1.
Parameters:
- irreps (o3.Irreps): The input irreps.
- factor (float): The scaling factor.
Returns: - o3.Irreps: The scaled irreps.
- Return type: o3.Irreps
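A minimal sketch of the scaling rule, assuming the scaled multiplicity is truncated to an integer before the lower bound of 1 is applied (the reference does not say how rounding is done). The tuple representation and the function name are illustrative.

```python
from typing import List, Tuple

Irrep = Tuple[int, int, int]  # hypothetical stand-in: (multiplicity, l, parity)


def scale_irreps_sketch(irreps: List[Irrep], factor: float) -> List[Irrep]:
    """Scale each multiplicity by `factor`, never dropping below 1."""
    # Assumption: truncate with int(); the real function may round differently.
    return [(max(1, int(mul * factor)), l, p) for mul, l, p in irreps]


scaled = scale_irreps_sketch([(16, 0, 1), (8, 1, -1), (1, 2, 1)], 0.5)
print(scaled)  # [(8, 0, 1), (4, 1, -1), (1, 2, 1)]
```

The last entry shows the lower bound at work: 1 × 0.5 truncates to 0 and is clamped back to 1, so no irrep disappears.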
- hamgnn.utils.irreps_utils.filter_and_split_irreps(irreps, num_channels, min_l, max_l)[source]
Filters and splits irreducible representations (irreps) based on specified angular momentum range.
Parameters:
- irreps (o3.Irreps): The input irreducible representations.
- num_channels (int): The number of channels to split the multiplicity by.
- min_l (int): The minimum angular momentum (inclusive).
- max_l (int): The maximum angular momentum (inclusive).
Returns: - o3.Irreps: The resulting irreducible representations after filtering and splitting.
- Return type: o3.Irreps
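The filter-then-split behaviour might look like the sketch below. It assumes that "splitting" means integer division of each multiplicity by `num_channels`; that reading, the tuple representation, and the function name are not confirmed by the reference.

```python
from typing import List, Tuple

Irrep = Tuple[int, int, int]  # hypothetical stand-in: (multiplicity, l, parity)


def filter_and_split_irreps_sketch(
    irreps: List[Irrep], num_channels: int, min_l: int, max_l: int
) -> List[Irrep]:
    """Keep irreps with min_l <= l <= max_l and divide their multiplicities."""
    return [
        (mul // num_channels, l, p)  # assumption: integer division
        for mul, l, p in irreps
        if min_l <= l <= max_l
    ]


# "32x0e + 16x1o + 8x2e + 8x3o", keeping only l = 1 and l = 2.
kept = filter_and_split_irreps_sketch(
    [(32, 0, 1), (16, 1, -1), (8, 2, 1), (8, 3, -1)],
    num_channels=4, min_l=1, max_l=2,
)
print(kept)  # [(4, 1, -1), (2, 2, 1)]
```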
Triplets
Math Utilities
- hamgnn.utils.math_utils.count_neighbors_per_node(source_nodes)[source]
Calculate the number of neighbors for each node.
- Parameters:
source_nodes (torch.Tensor) – A tensor containing source node indices.
- Returns:
A tensor where each index represents a node and the value at that index is the count of its neighbors.
- Return type:
torch.Tensor
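As an illustration of what such a count looks like, the sketch below uses `torch.bincount` on the source index of each edge. Whether HamGNN uses `bincount`, `torch.unique`, or something else is not stated here, so treat the function name and the approach as assumptions.

```python
import torch


def count_neighbors_per_node_sketch(source_nodes: torch.Tensor) -> torch.Tensor:
    """Count how often each node index appears as an edge source."""
    # bincount expects a 1-D tensor of non-negative integers.
    return torch.bincount(source_nodes)


# Six edges: node 0 has 2 neighbors, node 1 has 1, node 2 has 3.
edge_src = torch.tensor([0, 0, 1, 2, 2, 2])
counts = count_neighbors_per_node_sketch(edge_src)
print(counts)  # tensor([2, 1, 3])
```

With this approach a node that never appears as a source, but has an index below the maximum, gets a count of 0.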
- hamgnn.utils.math_utils.blockwise_2x2_concat(top_left, top_right, bottom_left, bottom_right)[source]
Concatenates four tensors in a 2x2 block pattern to form a doubled-size tensor. The concatenation pattern follows:
[top_left | top_right]
----------------------
[bottom_left | bottom_right]
- Parameters:
top_left (torch.Tensor) – Tensor of shape [N, H, W]
top_right (torch.Tensor) – Tensor of same shape as top_left
bottom_left (torch.Tensor) – Tensor of same shape as top_left
bottom_right (torch.Tensor) – Tensor of same shape as top_left
- Returns:
Concatenated tensor of shape [N, 2H, 2W]
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensors have mismatching shapes
Example
>>> a = torch.ones(2, 3, 3)
>>> b = torch.zeros(2, 3, 3)
>>> result = blockwise_2x2_concat(a, b, b, a)
>>> result.shape
torch.Size([2, 6, 6])
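For readers who want to see how the block layout maps onto tensor dimensions, here is a minimal sketch built on `torch.cat`. It reproduces the documented shapes and the `ValueError`, but the function name and the internal steps are assumptions rather than the library's code.

```python
import torch


def blockwise_2x2_concat_sketch(top_left, top_right, bottom_left, bottom_right):
    """Assemble four [N, H, W] tensors into one [N, 2H, 2W] tensor."""
    blocks = (top_left, top_right, bottom_left, bottom_right)
    if len({tuple(t.shape) for t in blocks}) != 1:
        raise ValueError("All four input tensors must have the same shape")
    # Join left and right halves along the last (column) dimension.
    top = torch.cat([top_left, top_right], dim=-1)           # [N, H, 2W]
    bottom = torch.cat([bottom_left, bottom_right], dim=-1)  # [N, H, 2W]
    # Stack the two rows of blocks along the row dimension.
    return torch.cat([top, bottom], dim=-2)                  # [N, 2H, 2W]


a = torch.ones(2, 3, 3)
b = torch.zeros(2, 3, 3)
result = blockwise_2x2_concat_sketch(a, b, b, a)
print(result.shape)  # torch.Size([2, 6, 6])
```

In this arrangement the ones occupy the two diagonal blocks and the zeros the off-diagonal blocks of each 6x6 matrix.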
- hamgnn.utils.math_utils.extract_elements_above_threshold(condition_tensor, source_tensor, threshold=0.0)[source]
Extracts elements from source tensor where condition tensor exceeds threshold.
- Parameters:
condition_tensor (torch.Tensor) – Tensor[Nbatch, N, N] used for threshold comparison
source_tensor (torch.Tensor) – Tensor[Nbatch, N, N] from which values are extracted
threshold (float) – Minimum value for elements in condition_tensor to trigger extraction
- Returns:
1D tensor of extracted values from source_tensor
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensors have mismatching shapes
Example
>>> S = torch.randn(2, 3, 3)
>>> H = torch.randn(2, 3, 3)
>>> result = extract_elements_above_threshold(S, H, 0.5)
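Because the example above uses random inputs, a deterministic sketch may make the selection rule easier to follow. It assumes "exceeds" means a strict `>` comparison and uses boolean-mask indexing; the function name is hypothetical.

```python
import torch


def extract_elements_above_threshold_sketch(condition_tensor, source_tensor, threshold=0.0):
    """Return source values at positions where the condition exceeds the threshold."""
    if condition_tensor.shape != source_tensor.shape:
        raise ValueError("condition_tensor and source_tensor must have the same shape")
    mask = condition_tensor > threshold  # assumption: strict comparison
    return source_tensor[mask]           # boolean indexing flattens to 1-D


S = torch.tensor([[[1.0, 0.0], [0.2, 0.9]]])      # condition, shape [1, 2, 2]
H = torch.tensor([[[10.0, 20.0], [30.0, 40.0]]])  # source, same shape
picked = extract_elements_above_threshold_sketch(S, H, 0.5)
print(picked)  # tensor([10., 40.])
```

Only the entries of `H` aligned with `S` values of 1.0 and 0.9 survive, returned in row-major order.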
- hamgnn.utils.math_utils.upgrade_tensor_precision(tensor_dict)[source]
Upgrades the precision of specific tensor types in the provided dictionary. This function iterates through the given dictionary and converts:
- torch.float32 (float) tensors to torch.float64 (double)
- torch.complex64 tensors to torch.complex128
All other tensor types remain unchanged. The original device of each tensor is preserved during conversion.
- Parameters:
tensor_dict – Dictionary containing torch tensors with string keys and torch tensor values.
- Returns:
The function modifies the dictionary in-place.
- Return type:
None
Notes
For float32 tensors, either .to(dtype=torch.float64) or .double() can be used for conversion. This function uses the .to() method for consistency with complex tensor conversion.
Example
>>> data = {'float_tensor': torch.tensor([1.0, 2.0], dtype=torch.float32)}
>>> upgrade_tensor_precision(data)
>>> print(data['float_tensor'].dtype)
torch.float64
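The conversion rules listed for this function can be sketched as follows, covering the complex and untouched cases that the short example leaves out. The function name and the skipping of non-tensor values are assumptions made for the illustration.

```python
import torch


def upgrade_tensor_precision_sketch(tensor_dict):
    """Promote float32 -> float64 and complex64 -> complex128 in place."""
    for key, value in list(tensor_dict.items()):
        if not torch.is_tensor(value):
            continue  # assumption: non-tensor values are left alone
        if value.dtype == torch.float32:
            tensor_dict[key] = value.to(dtype=torch.float64)
        elif value.dtype == torch.complex64:
            tensor_dict[key] = value.to(dtype=torch.complex128)
        # Any other dtype (for example int64) is kept unchanged.


data = {
    "h": torch.tensor([1.0, 2.0], dtype=torch.float32),
    "c": torch.tensor([1 + 2j], dtype=torch.complex64),
    "idx": torch.tensor([0, 1]),
}
upgrade_tensor_precision_sketch(data)
print({k: v.dtype for k, v in data.items()})
# {'h': torch.float64, 'c': torch.complex128, 'idx': torch.int64}
```

Calling `.to(dtype=...)` without a `device` argument leaves each tensor on its current device, which matches the device-preservation statement above.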
Loss Functions
- class hamgnn.utils.losses.cosine_similarity_loss[source]
Bases: Module
- forward(pred, target)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hamgnn.utils.losses.sum_zero_loss[source]
Bases: Module
- forward(pred, target)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hamgnn.utils.losses.Euclidean_loss[source]
Bases: Module
- forward(pred, target)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class hamgnn.utils.losses.RMSELoss[source]
Bases: Module
- forward(pred, target)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
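The note attached to each loss class can be demonstrated with a small experiment. The class below is a hypothetical root-mean-square-error module written for this illustration; it is not HamGNN's `RMSELoss`, whose exact computation is not documented here. The point is only the difference between calling the module and calling `forward` directly.

```python
import torch
from torch import nn


class RMSELossSketch(nn.Module):
    """Hypothetical stand-in loss: sqrt of the mean squared error."""

    def forward(self, pred, target):
        return torch.sqrt(torch.mean((pred - target) ** 2))


loss_fn = RMSELossSketch()
calls = []
# A forward hook receives (module, inputs, output) after each module call.
handle = loss_fn.register_forward_hook(
    lambda module, inputs, output: calls.append(output.item())
)

pred = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])

loss = loss_fn(pred, target)             # goes through __call__, hook runs
direct = loss_fn.forward(pred, target)   # bypasses __call__, hook is skipped
handle.remove()

print(len(calls))  # 1
```

Both calls return the same value, but only the first one is seen by the hook, which is why the loss classes should be invoked as `loss_fn(pred, target)`.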