Tensor Operations

Tensor Products

Optimized e3nn tensor products and linear layers with external radial weights.

Implements memory-tuned tensor products, KAN-backed (Kolmogorov-Arnold network) pathways, and helpers shared by convolution and message-passing modules.

class hamgnn.nn.tensor_products.LinearScaleWithWeights(irreps_in, irreps_out)

Bases: Module

forward(x, weight)

Applies the layer to the input features x using the externally supplied weight tensor.

Note

Call the module instance (for example, layer(x, weight)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
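The section introduction describes linear layers whose weights are supplied from outside the module. The sketch below illustrates that idea in plain Python; it is not the code of LinearScaleWithWeights, whose exact behavior is not documented here. It assumes a hypothetical feature layout (a dict from an irrep label such as "1o" to a list of channels, each holding 2l + 1 components) and a hypothetical helper, linear_with_external_weights, that mixes channels of the same irrep type with a caller-supplied matrix.

```python
from typing import Dict, List

# Hypothetical feature layout: irrep label -> list of `mul` channels,
# each channel holding the 2l + 1 components of that irrep.
Features = Dict[str, List[List[float]]]
# Hypothetical weight layout: irrep label -> [mul_out][mul_in] matrix.
Weights = Dict[str, List[List[float]]]


def linear_with_external_weights(x: Features, weight: Weights) -> Features:
    """Mix channels within each irrep type using caller-supplied weights.

    Channels are only mixed with channels of the same irrep type, and all
    2l + 1 components of a channel share one coefficient, so the map
    commutes with rotations.
    """
    out: Features = {}
    for label, w in weight.items():
        channels_in = x[label]
        dim = len(channels_in[0])  # 2l + 1 components per channel
        out[label] = [
            [sum(row[i] * channels_in[i][m] for i in range(len(channels_in)))
             for m in range(dim)]
            for row in w
        ]
    return out


x = {"0e": [[1.0], [2.0]], "1o": [[1.0, 0.0, 0.0]]}
w = {"0e": [[0.5, 0.5]], "1o": [[2.0], [-1.0]]}  # e.g. produced per edge by an MLP
print(linear_with_external_weights(x, w))
# {'0e': [[1.5]], '1o': [[2.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]}
```

Because every component of a channel is multiplied by the same coefficient, rotating each input channel and then applying the map gives the same result as applying the map and then rotating. This is what allows the weights to come from any rotation-invariant source, such as an MLP over edge lengths.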

class hamgnn.nn.tensor_products.TensorProductWithMemoryOptimizationWithWeight(irreps_input_1, irreps_input_2, irreps_out, irreps_scalar, radial_MLP, use_kan, lite_mode)

Bases: Module

forward(x, y, scalars)

Forward pass of the TensorProductWithMemoryOptimizationWithWeight module.

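Parameters:
  • x (torch.Tensor) – Features for the first input, with irreps irreps_input_1.

  • y (torch.Tensor) – Features for the second input, with irreps irreps_input_2.

  • scalars (torch.Tensor) – Scalar inputs (irreps irreps_scalar) for weight generation.

Returns: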

Output tensor after applying tensor products and scaling.

Return type:

torch.Tensor
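When tensor-product weights are produced externally (here, presumably by radial_MLP from the scalar inputs), the weight generator has to emit exactly one value per path coefficient. The sketch below counts those values for a fully connected tensor product using the standard selection rules: an output irrep with degree l_out and parity p_out is reachable from (l1, p1) and (l2, p2) only if |l1 - l2| <= l_out <= l1 + l2 and p1 * p2 == p_out, and each such path needs mul1 * mul2 * mul_out weights. This mirrors how e3nn's FullyConnectedTensorProduct sizes its weights, but the (mul, l, parity) tuple format and the helper names are hypothetical, and whether TensorProductWithMemoryOptimizationWithWeight uses a fully connected layout is an assumption.

```python
from typing import List, Tuple

# Hypothetical irreps format: (multiplicity, l, parity), parity +1 (even) or -1 (odd).
Irrep = Tuple[int, int, int]


def allowed(l1: int, p1: int, l2: int, p2: int, l_out: int, p_out: int) -> bool:
    """Selection rule for coupling (l1, p1) x (l2, p2) -> (l_out, p_out)."""
    return abs(l1 - l2) <= l_out <= l1 + l2 and p1 * p2 == p_out


def fully_connected_weight_numel(irreps1: List[Irrep], irreps2: List[Irrep],
                                 irreps_out: List[Irrep]) -> int:
    """Number of weights a fully connected ("uvw") tensor product needs."""
    total = 0
    for mul1, l1, p1 in irreps1:
        for mul2, l2, p2 in irreps2:
            for mul_out, l_out, p_out in irreps_out:
                if allowed(l1, p1, l2, p2, l_out, p_out):
                    total += mul1 * mul2 * mul_out
    return total


# Node features 8x0e + 8x1o, edge spherical harmonics 1x0e + 1x1o,
# output 8x0e + 8x1o.
node = [(8, 0, 1), (8, 1, -1)]
sh = [(1, 0, 1), (1, 1, -1)]
out = [(8, 0, 1), (8, 1, -1)]
print(fully_connected_weight_numel(node, sh, out))  # 256
```

Four paths survive the selection rule here (0e x 0e -> 0e, 0e x 1o -> 1o, 1o x 0e -> 1o, 1o x 1o -> 0e), each costing 8 * 1 * 8 = 64 weights, so a per-edge weight generator would need 256 outputs in this configuration. In e3nn, a tensor product built with shared_weights=False exposes the same count as its weight_numel attribute.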

class hamgnn.nn.tensor_products.TensorProductWithScalarComponents(irreps_input_1, irreps_input_2, irreps_out)

Bases: Module

A module for performing tensor products with memory optimization.

Parameters:
  • irreps_input_1 (str) – Irreducible representations for the first input.

  • irreps_input_2 (str) – Irreducible representations for the second input.

  • irreps_out (str) – Irreducible representations for the output.
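The irreps arguments are given as strings. Assuming they follow e3nn's notation (for example "8x0e + 4x1o", meaning eight even scalars and four odd vectors), the sketch below parses such a string, computes the flattened feature dimension as the sum of mul * (2l + 1), and locates the slices occupied by scalar (l = 0) blocks. The helpers parse_irreps, irreps_dim, and scalar_slices are hypothetical and handle only the e/o parity suffixes; in practice e3nn.o3.Irreps provides this parsing (and the dimension through its dim attribute).

```python
import re
from typing import List, Tuple

# One term such as "8x0e", "1o" (multiplicity defaults to 1) or "2x2e".
_TERM = re.compile(r"^(?:(\d+)x)?(\d+)([eo])$")


def parse_irreps(spec: str) -> List[Tuple[int, int, int]]:
    """Parse "8x0e + 4x1o" into [(mul, l, parity), ...] with parity +1/-1."""
    terms = []
    for raw in spec.split("+"):
        match = _TERM.match(raw.strip())
        if match is None:
            raise ValueError(f"cannot parse irrep term {raw!r}")
        mul = int(match.group(1) or 1)
        l = int(match.group(2))
        parity = 1 if match.group(3) == "e" else -1
        terms.append((mul, l, parity))
    return terms


def irreps_dim(spec: str) -> int:
    """Flattened feature size: sum of mul * (2l + 1)."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(spec))


def scalar_slices(spec: str) -> List[slice]:
    """Slices of the flattened feature vector occupied by l = 0 blocks."""
    slices, offset = [], 0
    for mul, l, _ in parse_irreps(spec):
        size = mul * (2 * l + 1)
        if l == 0:
            slices.append(slice(offset, offset + size))
        offset += size
    return slices


print(irreps_dim("8x0e + 4x1o + 2x2e"))     # 30
print(scalar_slices("4x1o + 8x0e + 2x0o"))  # [slice(12, 20, None), slice(20, 22, None)]
```

Knowing where the l = 0 blocks sit is what lets a module route only rotation-invariant channels into an ordinary MLP; odd-parity scalars (0o) are still flipped in sign by inversion, so they need odd activation functions if parity is to be respected.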

forward(x, y)

Forward pass of the module.

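Parameters:
  • x (torch.Tensor) – Features for the first input, with irreps irreps_input_1.

  • y (torch.Tensor) – Features for the second input, with irreps irreps_input_2.

Returns: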

Output tensor after applying tensor products and scaling.

Return type:

torch.Tensor

class hamgnn.nn.tensor_products.ConcatenatedIrrepsTensorProduct(irreps_in1, irreps_in2, num_tensors_in1, irreps_out, irreps_edge_scalars, radial_MLP, use_kan)

Bases: Module

forward(input_tensors1_list, input_tensor2, scalars)

Forward pass for the ConcatenatedIrrepsTensorProduct module.

Parameters:
  • input_tensors1_list (List[torch.Tensor]) – List of tensors for the first input.

  • input_tensor2 (torch.Tensor) – Tensor for the second input.

  • scalars (torch.Tensor) – Scalar inputs for weight generation.

Returns:

Processed output tensor.

Return type:

torch.Tensor
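ConcatenatedIrrepsTensorProduct takes a list of first-input tensors together with num_tensors_in1. One plausible reading, which is an assumption and not something documented here, is that the list is concatenated along the feature axis so that a single tensor product acts on irreps_in1 repeated num_tensors_in1 times. The sketch below shows that bookkeeping on plain Python lists; concat_irreps and concat_features are hypothetical helpers.

```python
from typing import List


def concat_irreps(irreps_in1: str, num_tensors: int) -> str:
    """Irreps string of num_tensors copies of irreps_in1 placed side by side."""
    return " + ".join([irreps_in1] * num_tensors)


def concat_features(tensors: List[List[List[float]]], dim_in1: int) -> List[List[float]]:
    """Concatenate per-node feature rows from several inputs along the feature axis."""
    num_nodes = len(tensors[0])
    for t in tensors:
        if len(t) != num_nodes:
            raise ValueError("all inputs must have the same number of rows")
        if any(len(row) != dim_in1 for row in t):
            raise ValueError(f"each row must have {dim_in1} features")
    # Row n of the result is row n of every input, joined in list order.
    return [sum((t[n] for t in tensors), []) for n in range(num_nodes)]


a = [[1.0, 2.0, 3.0, 4.0]]  # one node, irreps "1x0e + 1x1o" (dim 4)
b = [[5.0, 6.0, 7.0, 8.0]]
print(concat_irreps("1x0e + 1x1o", 2))  # 1x0e + 1x1o + 1x0e + 1x1o
print(concat_features([a, b], 4))       # [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]
```

Addition of e3nn Irreps objects concatenates in the same way, without merging equal irreps, so the block order of the concatenated features matches the order of the irreps string.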