Model Structure

Base Model

Base neural-network module and graph construction utilities for HamGNN.

Provides BaseModel (abstract forward; batch graph generation from structures), atomic-radius tables aligned with DFT codes (OpenMX, ABACUS), ASE-based neighbor lists with periodic boundary conditions, and helpers to align recomputed edges with pre-existing edge data via column matching.

hamgnn.models.base_model.get_radii_from_atomic_numbers(atomic_numbers, radius_scale=1.5, radius_type='openmx')[source]

Retrieves the scaled atomic radii for a given list or tensor of atomic numbers.

Parameters:
  • atomic_numbers (Union[torch.Tensor, List[int]]) – A list or tensor containing atomic numbers.

  • radius_scale (float) – Scaling factor applied to the atomic radii. Default is 1.5.

  • radius_type (str) – The DFT code whose atomic-radius table is used. Default is 'openmx'.

Returns:

A list of scaled atomic radii corresponding to the input atomic numbers.

Return type:

List[float]
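The lookup-and-scale behaviour described above can be sketched as follows. This is a minimal illustration, not HamGNN's implementation: the table `PLACEHOLDER_RADII`, its values, and the function name `scaled_radii` are all made up for the example and do not reproduce the OpenMX or ABACUS radii.

```python
from typing import Dict, Iterable, List

# Hypothetical table: atomic number -> radius. Placeholder values only,
# NOT the radii HamGNN uses for OpenMX or ABACUS.
PLACEHOLDER_RADII: Dict[str, Dict[int, float]] = {
    "openmx": {1: 1.0, 6: 2.0, 8: 2.0},
}


def scaled_radii(atomic_numbers: Iterable[int],
                 radius_scale: float = 1.5,
                 radius_type: str = "openmx") -> List[float]:
    """Look up each atomic number in the chosen table and scale its radius."""
    table = PLACEHOLDER_RADII[radius_type]
    return [radius_scale * table[int(z)] for z in atomic_numbers]


radii = scaled_radii([1, 6, 8])
```

A larger radius_scale widens the per-atom cutoff and therefore tends to produce more edges when the graph is built.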

hamgnn.models.base_model.neighbor_list_and_relative_vec(pos, r_max, self_interaction=False, strict_self_interaction=True, cell=None, pbc=False)[source]

Create neighbor list and neighbor vectors based on radial cutoff.

Edges are given by the following convention:
  • edge_index[0] is the source (convolution center).

  • edge_index[1] is the target (neighbor).

Parameters:
  • pos (shape [N, 3]) – Positional coordinates; Tensor or numpy array.

  • r_max (float) – Radial cutoff distance for neighbor finding.

  • cell (numpy shape [3, 3]) – Cell for periodic boundary conditions.

  • pbc (bool or 3-tuple of bool) – Periodicity in each of the three dimensions.

  • self_interaction (bool) – Include same periodic image self-edges.

  • strict_self_interaction (bool) – Include any self interaction edges.

Returns:

  • edge_index (torch.Tensor [2, num_edges]): List of edges.

  • shifts (torch.Tensor [num_edges, 3]): Relative cell shift vectors.

  • cell_tensor (torch.Tensor [3, 3]): Cell tensor.

Return type:

tuple
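To make the edge convention and the role of the cell shifts concrete, here is a brute-force sketch in plain Python. It is not the ASE-based routine used by the library: the function name `brute_force_neighbors` is hypothetical, only the 27 nearest periodic images are scanned (so it assumes r_max does not exceed the cell dimensions), and the cell rows are assumed to be the lattice vectors.

```python
import itertools
import math


def brute_force_neighbors(pos, r_max, cell, pbc=(True, True, True)):
    """Return (edge_index, shifts) for all pairs closer than r_max.

    edge_index[0] holds the centers, edge_index[1] the neighbors.
    Same-image self-edges are dropped; self-edges to periodic images are kept.
    """
    image_ranges = [(-1, 0, 1) if periodic else (0,) for periodic in pbc]
    centers, neighbors, shifts = [], [], []
    for i, j in itertools.product(range(len(pos)), repeat=2):
        for shift in itertools.product(*image_ranges):
            if i == j and shift == (0, 0, 0):
                continue  # an atom is not its own neighbor in the same image
            # Vector from atom i to the shifted periodic image of atom j.
            vec = [
                pos[j][k] - pos[i][k]
                + sum(shift[a] * cell[a][k] for a in range(3))
                for k in range(3)
            ]
            if math.sqrt(sum(v * v for v in vec)) < r_max:
                centers.append(i)
                neighbors.append(j)
                shifts.append(shift)
    return [centers, neighbors], shifts


# One atom in a cubic cell of edge 2.0: its six nearest images lie at distance 2.0.
cell = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
edge_index, shifts = brute_force_neighbors([[0.0, 0.0, 0.0]], 2.5, cell)
```

The relative vector of an edge can always be rebuilt from the positions, the shift, and the cell, which is why the shifts are returned alongside edge_index.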

hamgnn.models.base_model.find_matching_columns_of_A_in_B(A, B)[source]

Finds matching columns between two matrices A and B.

Parameters:
  • A (torch.Tensor) – First matrix.

  • B (torch.Tensor) – Second matrix.

Returns:

Indices of matching columns in B.

Return type:

torch.Tensor
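One way to picture column matching is a dictionary lookup keyed on whole columns. The sketch below uses nested Python lists in place of tensors so it runs without dependencies; `matching_column_indices` is a hypothetical name, and it assumes every column of A occurs exactly once in B, which the real function may not require.

```python
def matching_column_indices(A, B):
    """For each column of A, return the index of the identical column in B."""
    # Columns are read as tuples so they can serve as dictionary keys.
    b_lookup = {tuple(row[j] for row in B): j for j in range(len(B[0]))}
    return [b_lookup[tuple(row[j] for row in A)] for j in range(len(A[0]))]


A = [[0, 2],
     [1, 0]]        # columns: (0, 1), (2, 0)
B = [[2, 0, 1],
     [0, 1, 2]]     # columns: (2, 0), (0, 1), (1, 2)
idx = matching_column_indices(A, B)
```

With edge_index tensors as A and B, indices of this kind let per-edge data stored in the order of B be gathered into the order of A.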

class hamgnn.models.base_model.BaseModel(radius_type='openmx', radius_scale=1.5)[source]

Bases: Module

forward(data)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

generate_graph(data)[source]
property num_params: int

HamGNN Convolution

E3-equivariant message-passing backbone (HamGNNConvE3) for HamGNN.

Stacks embedding, interaction, and atomwise blocks with NequIP-compatible atomic data and radial/spherical encodings.

class hamgnn.models.hamgnn_conv.LayerCheckpointModule(conv, corr, pair, use_corr_prod)[source]

Bases: Module

Wraps per-layer blocks for gradient checkpointing with safe tensor cloning.

All three blocks (ConvBlockE3, CorrProductBlock, PairInteractionBlock) mutate the graph dict in-place (e.g., data[NODE_FEATURES_KEY] = output_features). This causes incorrect gradients when wrapped naively with checkpoint() because checkpoint saves tensor storage and in-place ops corrupt the saved values.

Solution: Clone input tensors before creating a working dict, so checkpointed backward pass reconstructs from clean copies.

forward(node_feats, edge_feats, graph)[source]

class hamgnn.models.hamgnn_conv.HamGNNConvE3(config)[source]

Bases: BaseModel

forward(data)[source]

HamGNN Transformer

Transformer-style E3-equivariant representation (HamGNNTransformer) for HamGNN.

Uses attention blocks alongside pair interactions and spherical edge attributes.

class hamgnn.models.hamgnn_transformer.HamGNNTransformer(config)[source]

Bases: BaseModel

forward(data)[source]

HamGNN Output Head

Hamiltonian and band-structure output layers, tensor products, and k-space utilities.

Implements equivariant output heads (e.g. HamLayer), structure handling via pymatgen, and scattering-based reductions for graph batches.

class hamgnn.models.hamgnn_output.HamLayer(irreps_in, feature_irreps_hidden, irreps_out, nonlinearity_type='gate', resnet=True)[source]

Bases: Module

forward(x)[source]

class hamgnn.models.hamgnn_output.HamGNNPlusPlusOut(irreps_in_node=None, irreps_in_edge=None, nao_max=14, return_forces=False, create_graph=False, ham_type='openmx', ham_only=False, symmetrize=True, include_triplet=False, calculate_band_energy=False, num_k=8, k_path=None, band_num_control=None, soc_switch=True, nonlinearity_type='gate', export_reciprocal_values=False, add_H0=False, soc_basis='so3', spin_constrained=False, use_learned_weight=True, minMagneticMoment=0.5, collinear_spin=False, zero_point_shift=False, add_H_nonsoc=False, get_nonzero_mask_tensor=False, calculate_sparsity=True)[source]

Bases: Module

Neural network module for computing Hamiltonian matrices using Graph Neural Networks.

This class implements a GNN architecture for quantum chemistry applications, supporting different Hamiltonian types (openmx, siesta, abacus, pasp) with various configuration options for physics-informed neural network training.

Parameters:
  • irreps_in_node (Union[int, str, o3.Irreps]) – Input irreps for node features.

  • irreps_in_edge (Union[int, str, o3.Irreps]) – Input irreps for edge features.

  • nao_max (int) – Maximum number of atomic orbitals.

  • return_forces (bool) – Whether to compute forces during forward pass.

  • create_graph (bool) – Whether to create computational graph for backpropagation.

  • ham_type (str) – Type of Hamiltonian (‘openmx’, ‘siesta’, ‘abacus’, or ‘pasp’).

  • ham_only (bool) – Whether to only compute the Hamiltonian matrix.

  • symmetrize (bool) – Whether to symmetrize the Hamiltonian matrix.

  • include_triplet (bool) – Whether to include triplet interactions.

  • calculate_band_energy (bool) – Whether to calculate band energies.

  • num_k (int) – Number of k-points for band structure calculations.

  • k_path (Union[list, np.ndarray, tuple]) – Path in k-space for band structure.

  • band_num_control (dict or int) – Control for band numbers.

  • soc_switch (bool) – Whether to include spin-orbit coupling.

  • nonlinearity_type (str) – Type of nonlinearity for neural network layers.

  • export_reciprocal_values (bool) – Whether to export reciprocal space values.

  • add_H0 (bool) – Whether to add the initial Hamiltonian term H0.

  • soc_basis (str) – Basis for spin-orbit coupling (‘so3’ or ‘su2’).

  • spin_constrained (bool) – Whether to apply spin constraints.

  • use_learned_weight (bool) – Whether to use learned weights.

  • minMagneticMoment (float) – Minimum magnetic moment threshold.

  • collinear_spin (bool) – Whether to consider collinear spin.

  • zero_point_shift (bool) – Whether to apply zero-point energy shift.

  • add_H_nonsoc (bool) – Whether to add non-SOC Hamiltonian.

  • get_nonzero_mask_tensor (bool) – Whether to get nonzero mask tensor.

merge_tensor_components(spherical_components)[source]

Merge spherical tensor components into a matrix using Clebsch-Gordan coefficients.

Parameters:

spherical_components (list of torch.Tensor) – List of tensors representing spherical components of irreducible representations. Each tensor has shape (batch_size, component_dim).

Returns:

Merged matrix with shape (batch_size, nao_max * nao_max), representing the flattened matrix of merged components.

Return type:

torch.Tensor

merge_rank2_tensor_components(spherical_components)[source]

Merge rank-2 tensor components into a block matrix representation.

This function is specialized for rank-2 tensors and applies permutation to the result.

Parameters:

spherical_components (list of torch.Tensor) – List of tensors representing spherical components of irreducible representations.

Returns:

Merged block matrix with shape (batch_size, n_blocks, 3, 3) after applying a specific permutation.

Return type:

torch.Tensor

merge_rank0_tensor_components(spherical_components)[source]

Merge rank-0 (scalar) tensor components into a matrix representation.

Parameters:

spherical_components (list of torch.Tensor) – List of tensors representing spherical components of irreducible representations.

Returns:

Merged matrix with shape (batch_size, nao_max, nao_max).

Return type:

torch.Tensor

construct_j_coupling_matrix(coupling_coefficients)[source]

Construct a matrix representation of J-coupling (spin-orbit interaction).

Builds a matrix representation for J-coupling which can be either rank-2 (with spin-orbit coupling) or rank-0.

Parameters:

coupling_coefficients (torch.Tensor) – Tensor containing coupling coefficients.

Returns:

The J-coupling matrix. If spin-orbit coupling is enabled, shape is (batch_size, nao_max, nao_max, 3, 3); otherwise, shape is (batch_size, nao_max, nao_max).

Return type:

torch.Tensor

construct_k_coupling_matrix(coupling_coefficients)[source]

Construct a matrix representation of exchange coupling (K-term).

Parameters:

coupling_coefficients (torch.Tensor) – Tensor containing exchange coupling coefficients with shape (n_atoms_or_edges, coefficients_per_block * n_blocks).

Returns:

Exchange coupling matrix with shape (n_atoms_or_edges, nao_max, nao_max).

Return type:

torch.Tensor

reorder_matrix(matrix)[source]

Reorder matrix elements to match the atomic orbital convention used by DFT.

This function performs two types of transformations:
  1. Reorders rows and columns according to a predefined permutation (if defined)
  2. Flips the sign of specific elements (if defined)

These transformations ensure compatibility with DFT’s atomic orbital ordering convention, which may differ from the internal representation.

Parameters:

matrix (torch.Tensor) – Input matrix in flattened form, with shape (batch_size, nao_max^2) where nao_max is the maximum number of atomic orbitals.

Returns:

Reordered matrix in the same shape as input, but with elements rearranged to match DFT conventions.

Return type:

torch.Tensor
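The two transformations can be sketched on a single flattened matrix as below. The names `reorder_flat_matrix`, `perm`, and `sign_flips` are hypothetical, and the sign rule (an element flips when exactly one of its two orbitals is flagged, as happens when a basis function changes sign) is an assumption about how the flips are defined, not a description of the library's tables.

```python
def reorder_flat_matrix(flat, n, perm=None, sign_flips=None):
    """Permute rows/columns of a flattened n x n matrix, then flip signs."""
    mat = [flat[r * n:(r + 1) * n] for r in range(n)]
    if perm is not None:
        # New orbital r is taken from old orbital perm[r], for rows and columns alike.
        mat = [[mat[perm[r]][perm[c]] for c in range(n)] for r in range(n)]
    if sign_flips is not None:
        # Flip an element when exactly one of its row/column orbitals is flagged.
        mat = [
            [-mat[r][c] if (r in sign_flips) != (c in sign_flips) else mat[r][c]
             for c in range(n)]
            for r in range(n)
        ]
    return [x for row in mat for x in row]


permuted = reorder_flat_matrix([1, 2, 3, 4], n=2, perm=[1, 0])
flipped = reorder_flat_matrix([1, 2, 3, 4], n=2, perm=[1, 0], sign_flips={0})
```

The output keeps the flattened layout of the input, matching the "same shape as input" contract stated above.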

construct_molecular_hamiltonian(data, onsite_hamiltonian, offsite_hamiltonian)[source]

Construct a complete molecular Hamiltonian matrix from on-site and off-site components.

This function transforms separate on-site (diagonal) and off-site (interaction) Hamiltonian components into a unified molecular Hamiltonian matrix. It handles atom-specific orbital basis sets and masks invalid or padded orbitals.

Parameters:
  • data (DataObject) – Object containing crystal structure information, including:
      - z: Atomic numbers
      - node_counts: Number of atoms in each crystal
      - edge_index: Edges between atoms (indices of connected atom pairs)

  • onsite_hamiltonian (torch.Tensor) – On-site Hamiltonian matrix elements for each atom, with shape (n_atoms, nao_max^2)

  • offsite_hamiltonian (torch.Tensor) – Off-site Hamiltonian matrix elements for each edge, with shape (n_edges, nao_max^2)

Returns:

Complete molecular Hamiltonian matrix with shape (n_molecules, n_orbitals, n_orbitals), where n_orbitals is the number of valid atomic orbitals per molecule after removing padding.

Return type:

torch.Tensor

concatenate_hamiltonians_by_crystal(data, onsite_hamiltonians, offsite_hamiltonians)[source]

Concatenate on-site and off-site Hamiltonian matrices for each crystal in a batch.

This function organizes Hamiltonian matrices by crystal, interleaving on-site and off-site components in a specific order required for further processing.

Parameters:
  • data (DataObject) – Object containing crystal structure information, including:
      - node_counts: Number of atoms in each crystal
      - edge_index: Indices of connected atom pairs
      - batch: Batch assignment for each node

  • onsite_hamiltonians (torch.Tensor) – On-site Hamiltonian matrices for all atoms, with shape (total_atoms, matrix_dimension)

  • offsite_hamiltonians (torch.Tensor) – Off-site Hamiltonian matrices for all edges, with shape (total_edges, matrix_dimension)

Returns:

Concatenated Hamiltonian matrices organized by crystal, with alternating on-site and off-site blocks.

Return type:

torch.Tensor

symmetrize_hamiltonian(hamiltonian, is_soc=False, inverse_edges=None, symmetry_type='hermitian')[source]

Apply symmetrization to a Hamiltonian matrix.

This is a general-purpose function that handles various types of symmetrization for both on-site and off-site Hamiltonians, with or without spin-orbit coupling (SOC).

Parameters:
  • hamiltonian (torch.Tensor) – Hamiltonian matrix elements in flattened form.

  • is_soc (bool, optional) – Whether this is a spin-orbit coupling Hamiltonian with double dimension. Defaults to False.

  • inverse_edges (torch.Tensor, optional) – Indices mapping each edge to its inverse for off-site Hamiltonians. Required for off-site symmetrization. Defaults to None.

  • symmetry_type (str, optional) – Type of symmetry to apply. Defaults to ‘hermitian’.
      - ‘hermitian’: Apply Hermitian symmetry H = 0.5*(H + H†)
      - ‘anti-hermitian’: Apply anti-Hermitian symmetry H = 0.5*(H - H†)

Returns:

Symmetrized Hamiltonian matrix in the same format as input.

Return type:

torch.Tensor
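For a single on-site block, the two symmetry types amount to taking the Hermitian or anti-Hermitian part of the matrix. The sketch below uses nested lists and Python complex numbers; `hermitian_part` is a hypothetical helper and says nothing about how the library handles flattened or SOC-doubled layouts.

```python
def hermitian_part(H, sign=+1):
    """Return 0.5 * (H + sign * H†) for a square matrix given as nested lists.

    sign=+1 gives the Hermitian part, sign=-1 the anti-Hermitian part.
    """
    n = len(H)
    return [
        [0.5 * (H[r][c] + sign * complex(H[c][r]).conjugate()) for c in range(n)]
        for r in range(n)
    ]


H = [[1.0, 2.0 + 1.0j],
     [0.0, 3.0]]
H_sym = hermitian_part(H)            # satisfies H_sym == H_sym†
H_anti = hermitian_part(H, sign=-1)  # satisfies H_anti == -H_anti†
```

The two parts add back up to the original matrix, so the symmetrization discards exactly the component with the wrong symmetry.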

symmetrize_onsite_hamiltonian(hamiltonian, hermitian=True)[source]

Symmetrize on-site Hamiltonian matrices.

Applies Hermitian or anti-Hermitian symmetrization to on-site Hamiltonians.

Parameters:
  • hamiltonian (torch.Tensor) – On-site Hamiltonian matrix elements.

  • hermitian (bool, optional) – If True, apply Hermitian symmetry (H + H†), otherwise apply anti-Hermitian symmetry (H - H†). Defaults to True.

Returns:

Symmetrized Hamiltonian.

Return type:

torch.Tensor

symmetrize_offsite_hamiltonian(hamiltonian, inverse_edges, hermitian=True)[source]

Symmetrize off-site Hamiltonian matrices.

Applies Hermitian or anti-Hermitian symmetrization to off-site Hamiltonians, using the inverse edge mapping to relate connected atom pairs.

Parameters:
  • hamiltonian (torch.Tensor) – Off-site Hamiltonian matrix elements.

  • inverse_edges (torch.Tensor) – Tensor mapping each edge to its inverse edge index.

  • hermitian (bool, optional) – If True, apply Hermitian symmetry (H + H†), otherwise apply anti-Hermitian symmetry (H - H†). Defaults to True.

Returns:

Symmetrized Hamiltonian.

Return type:

torch.Tensor
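For off-site blocks, the partner in the symmetrization is not the block itself but the block of the reverse edge. A minimal sketch, assuming each edge e carries a square block and inverse_edges[e] indexes the edge pointing the other way; the 0.5 factor follows the convention quoted for symmetrize_hamiltonian, and the function names are hypothetical.

```python
def dagger(M):
    """Conjugate transpose of a square matrix stored as nested lists."""
    n = len(M)
    return [[complex(M[c][r]).conjugate() for c in range(n)] for r in range(n)]


def symmetrize_offsite(blocks, inverse_edges):
    """Replace each edge block H_e by 0.5 * (H_e + H_inv(e)†)."""
    out = []
    for e, block in enumerate(blocks):
        partner = dagger(blocks[inverse_edges[e]])
        n = len(block)
        out.append(
            [[0.5 * (block[r][c] + partner[r][c]) for c in range(n)] for r in range(n)]
        )
    return out


# Two edges that are each other's inverse (i -> j and j -> i).
blocks = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[1.0, 0.0], [0.0, 4.0]],
]
sym_blocks = symmetrize_offsite(blocks, inverse_edges=[1, 0])
```

After this step the block of each edge equals the conjugate transpose of its inverse edge's block, which is the condition for the assembled Hamiltonian to be Hermitian.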

symmetrize_onsite_hamiltonian_soc(hamiltonian, hermitian=True)[source]

Symmetrize on-site Hamiltonian matrices with spin-orbit coupling.

Applies Hermitian or anti-Hermitian symmetrization to on-site Hamiltonians that include spin-orbit coupling, which have double the dimension.

Parameters:
  • hamiltonian (torch.Tensor) – On-site SOC Hamiltonian matrix elements.

  • hermitian (bool, optional) – If True, apply Hermitian symmetry (H + H†), otherwise apply anti-Hermitian symmetry (H - H†). Defaults to True.

Returns:

Symmetrized SOC Hamiltonian.

Return type:

torch.Tensor

symmetrize_offsite_hamiltonian_soc(hamiltonian, inverse_edges, hermitian=True)[source]

Symmetrize off-site Hamiltonian matrices with spin-orbit coupling.

Applies Hermitian or anti-Hermitian symmetrization to off-site Hamiltonians that include spin-orbit coupling, which have double the dimension.

Parameters:
  • hamiltonian (torch.Tensor) – Off-site SOC Hamiltonian matrix elements.

  • inverse_edges (torch.Tensor) – Tensor mapping each edge to its inverse edge index.

  • hermitian (bool, optional) – If True, apply Hermitian symmetry (H + H†), otherwise apply anti-Hermitian symmetry (H - H†). Defaults to True.

Returns:

Symmetrized SOC Hamiltonian.

Return type:

torch.Tensor

calculate_band_energies_with_overlap(onsite_hamiltonian, offsite_hamiltonian, onsite_overlap, offsite_overlap, crystal_data, export_reciprocal_values=False)[source]

Calculate electronic band structure using provided Hamiltonian and overlap matrices.

This function computes electronic band energies, wavefunctions, and band gaps for a set of crystal structures using the generalized eigenvalue problem Hψ = ESψ. This version allows debugging by accepting both reference and predicted overlap matrices.

Parameters:
  • onsite_hamiltonian (torch.Tensor) – On-site Hamiltonian matrix elements with shape (total_atoms, nao_max^2).

  • offsite_hamiltonian (torch.Tensor) – Off-site Hamiltonian matrix elements with shape (total_edges, nao_max^2).

  • onsite_overlap (torch.Tensor) – Predicted on-site overlap matrix elements with shape (total_atoms, nao_max^2).

  • offsite_overlap (torch.Tensor) – Predicted off-site overlap matrix elements with shape (total_edges, nao_max^2).

  • crystal_data (DataObject) – Object containing crystal structure information including:
      - edge_index: Indices of connected atom pairs
      - cell: Unit cell vectors
      - z: Atomic numbers
      - node_counts: Number of atoms in each crystal
      - batch: Batch assignment for each atom
      - k_vecs: k-points for band structure calculation
      - nbr_shift: Neighbor cell shifts for periodic boundary conditions
      - Son/Soff: Reference overlap matrices

  • export_reciprocal_values (bool, optional) – Whether to export additional reciprocal space matrices (H(k), S(k), dS(k)). Defaults to False.

Returns:

Contains band energies, wavefunctions, and optionally additional reciprocal space matrices depending on the export_reciprocal_values parameter.

Return type:

tuple

calculate_band_energies(onsite_hamiltonian, offsite_hamiltonian, crystal_data, export_reciprocal_values=False)[source]

Calculate electronic band structure using Hamiltonian matrices and reference overlap matrices.

This function computes electronic band energies, wavefunctions, and band gaps for a set of crystal structures by solving the generalized eigenvalue problem Hψ = ESψ.

Parameters:
  • onsite_hamiltonian (torch.Tensor) – On-site Hamiltonian matrix elements with shape (total_atoms, nao_max^2).

  • offsite_hamiltonian (torch.Tensor) – Off-site Hamiltonian matrix elements with shape (total_edges, nao_max^2).

  • crystal_data (DataObject) – Object containing crystal structure information including:
      - edge_index: Indices of connected atom pairs
      - cell: Unit cell vectors
      - z: Atomic numbers
      - node_counts: Number of atoms in each crystal
      - batch: Batch assignment for each atom
      - k_vecs: k-points for band structure calculation
      - nbr_shift: Neighbor cell shifts for periodic boundary conditions
      - Son/Soff: Reference overlap matrices

  • export_reciprocal_values (bool, optional) – Whether to export additional reciprocal space matrices (H(k), S(k), dS(k)). Defaults to False.

Returns:

Contains band energies, wavefunctions, band gaps, and optionally additional reciprocal space matrices depending on the export_reciprocal_values parameter.

Return type:

tuple
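The step from real-space blocks to band energies can be illustrated with the smallest possible case: one orbital per cell on a 1D chain, where the generalized eigenvalue problem becomes a scalar division. Everything here is a toy assumption made for the example, including the numbers, the helper names, and the phase convention exp(2πi k·R) with k in fractional coordinates; the library's own convention and units may differ.

```python
import cmath
import math


def bloch_sum(real_space_terms, k):
    """Sum M(R) * exp(2*pi*i * k.R) over cell shifts R (k in fractional coordinates)."""
    return sum(
        value * cmath.exp(2j * math.pi * sum(ki * ri for ki, ri in zip(k, shift)))
        for shift, value in real_space_terms.items()
    )


# Toy model: on-site term plus nearest-neighbour hopping and overlap.
H_R = {(0, 0, 0): 0.0, (1, 0, 0): -1.0, (-1, 0, 0): -1.0}
S_R = {(0, 0, 0): 1.0, (1, 0, 0): 0.1, (-1, 0, 0): 0.1}


def band_energy(k):
    # With a single orbital, H(k) psi = E S(k) psi reduces to E = H(k) / S(k).
    return (bloch_sum(H_R, k) / bloch_sum(S_R, k)).real


energy_at_gamma = band_energy((0.0, 0.0, 0.0))
energy_at_zone_edge = band_energy((0.5, 0.0, 0.0))
```

With several orbitals per cell, H(k) and S(k) become matrices and the division is replaced by a generalized Hermitian eigensolver applied at each k-point.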

calculate_band_energies_with_spin_orbit_coupling(real_onsite, imag_onsite, real_offsite, imag_offsite, crystal_data)[source]

Calculate electronic band structure with spin-orbit coupling (SOC).

This function computes electronic band energies and wavefunctions for systems with spin-orbit coupling, which requires handling complex Hamiltonian matrices and doubling the matrix dimensions to account for spin.

Parameters:
  • real_onsite (torch.Tensor) – Real part of on-site SOC Hamiltonian with shape (total_atoms, 2*nao_max, 2*nao_max).

  • imag_onsite (torch.Tensor) – Imaginary part of on-site SOC Hamiltonian with shape (total_atoms, 2*nao_max, 2*nao_max).

  • real_offsite (torch.Tensor) – Real part of off-site SOC Hamiltonian with shape (total_edges, 2*nao_max, 2*nao_max).

  • imag_offsite (torch.Tensor) – Imaginary part of off-site SOC Hamiltonian with shape (total_edges, 2*nao_max, 2*nao_max).

  • crystal_data (DataObject) – Object containing crystal structure information including:
      - edge_index: Indices of connected atom pairs
      - cell: Unit cell vectors
      - z: Atomic numbers
      - node_counts: Number of atoms in each crystal
      - batch: Batch assignment for each atom
      - k_vecs: k-points for band structure calculation
      - cell_shift: Cell shift vectors for periodic images
      - nbr_shift: Neighbor cell shifts for periodic boundary conditions
      - Son/Soff: Reference overlap matrices (without SOC)

Returns:

  • band_energies (torch.Tensor): Band energies with shape (total_bands, num_k_points).

  • wavefunctions (torch.Tensor): Flattened wavefunctions.

Return type:

tuple

apply_orbital_masks_to_hamiltonians(onsite_hamiltonian, offsite_hamiltonian, data, return_masks=False)[source]

Apply atomic orbital validity masks to on-site and off-site Hamiltonian matrices.

This function zeroes out elements of Hamiltonian matrices that correspond to invalid or non-existent atomic orbitals based on the atomic numbers. This ensures physical correctness by preventing interactions involving orbitals that shouldn’t exist for particular atom types.

Parameters:
  • onsite_hamiltonian (torch.Tensor) – On-site Hamiltonian matrix elements with shape (n_atoms, nao_max^2) or (n_atoms, nao_max, nao_max).

  • offsite_hamiltonian (torch.Tensor) – Off-site Hamiltonian matrix elements with shape (n_edges, nao_max^2) or (n_edges, nao_max, nao_max).

  • data (DataObject) – Object containing:
      - z: Atomic numbers for each atom
      - edge_index: Indices of connected atom pairs (source, target)

  • return_masks (bool, optional) – Whether to return the masks along with masked Hamiltonians. Defaults to False.

Returns:

  • If return_masks is False: (masked_onsite_hamiltonian, masked_offsite_hamiltonian)

  • If return_masks is True: (masked_onsite_hamiltonian, masked_offsite_hamiltonian, onsite_orbital_mask, offsite_orbital_mask)

Return type:

tuple
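Masking can be pictured as an outer product of two per-atom validity vectors: an element survives only if both its row orbital and its column orbital exist. In the sketch below the validity vectors, the tiny nao_max = 3, and the choice of rows for the source atom and columns for the target atom are all illustrative assumptions, and `apply_orbital_mask` is a hypothetical name.

```python
def apply_orbital_mask(block, valid_row, valid_col):
    """Zero every element whose row or column orbital is marked invalid."""
    return [
        [block[r][c] if valid_row[r] and valid_col[c] else 0.0
         for c in range(len(valid_col))]
        for r in range(len(valid_row))
    ]


# Hypothetical validity vectors for nao_max = 3, keyed by atomic number:
# 1 marks a real orbital, 0 marks padding. Illustrative only.
valid = {1: [1, 0, 0], 6: [1, 1, 1]}
z = [1, 6]        # atom 0 is H, atom 1 is C
edge = (0, 1)     # source atom 0, target atom 1

block = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0],
         [7.0, 8.0, 9.0]]
masked_offsite = apply_orbital_mask(block, valid[z[edge[0]]], valid[z[edge[1]]])
masked_onsite = apply_orbital_mask(block, valid[z[0]], valid[z[0]])
```

For an on-site block the same atom supplies both vectors, so the surviving region is always a square sub-block.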

symmetrize_orbital_coefficients(coefficient_matrix)[source]

Enforce spherical symmetry on orbital coefficient matrices by averaging within angular momentum blocks.

This function applies orbital symmetrization to ensure that coefficients maintain proper rotational invariance within each angular momentum subspace (p, d, f orbitals). Each block of coefficients corresponding to orbitals with the same angular momentum is averaged to enforce spherical symmetry constraints.

Orbital index ranges by angular momentum:
  • s orbitals: 0:3 (single orbital with additional indices)

  • p orbitals: 3:6 (3 components)

  • p’ orbitals: 6:9 (3 components)

  • d orbitals: 9:14 (5 components)

  • d’ orbitals: 14:19 (5 components, only for nao_max ≥ 19)

  • f orbitals: 19:26 (7 components, only for nao_max = 26)

Parameters:

coefficient_matrix (torch.Tensor) – Coefficient matrix with shape (batch_size, nao_max^2) or (batch_size, nao_max, nao_max).

Returns:

Symmetrized coefficient matrix with shape (batch_size, nao_max^2).

Return type:

torch.Tensor
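Block averaging over the index ranges listed above can be sketched on a single row of coefficients. This is only an illustration for nao_max = 14: it averages along one axis and leaves the s indices 0:3 untouched, whereas the text here does not state along which axes the real method averages. The name `average_within_blocks` is hypothetical.

```python
# Index ranges for the p, p' and d blocks when nao_max = 14.
BLOCKS = [(3, 6), (6, 9), (9, 14)]


def average_within_blocks(row, blocks=BLOCKS):
    """Replace the entries of each angular-momentum block by the block mean."""
    out = list(row)
    for start, stop in blocks:
        mean = sum(row[start:stop]) / (stop - start)
        out[start:stop] = [mean] * (stop - start)
    return out


row = [1.0, 2.0, 3.0,                 # s indices, left as they are
       1.0, 2.0, 3.0,                 # p
       4.0, 4.0, 4.0,                 # p'
       0.0, 1.0, 2.0, 3.0, 4.0]       # d
sym_row = average_within_blocks(row)
```

After averaging, every component within a block carries the same coefficient, so the result no longer distinguishes between the m components of a shell.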

create_cell_index_mapping(unique_cell_vectors)[source]

Create a mapping from unique cell vectors to their indices.

This function creates a dictionary that associates each unique cell vector (representing a unit cell in a periodic lattice) with its index in the list of unique vectors. This mapping is used for efficient lookup of cell indices when constructing Hamiltonian matrices with periodic boundary conditions.

Parameters:

unique_cell_vectors (List[List[int]]) – A list of cell vectors, where each vector is represented as a list of integers (typically 3 integers for 3D periodic systems).

Returns:

A dictionary mapping each cell vector (as a tuple) to its integer index.

Return type:

dict

extract_unique_cell_vectors(data)[source]

Extract unique cell vectors and create mappings for periodic boundary conditions.

This function processes cell shift data to:
  1. Find all unique cell shift vectors
  2. Ensure the zero vector (0,0,0) is included (required for on-site interactions)
  3. Create a mapping from each cell shift to its unique index
  4. Create an index array that maps each edge to its corresponding cell index

This information is essential for constructing Hamiltonian matrices with proper periodic boundary conditions.

Parameters:

data (DataObject) – Object containing crystal structure information, including:
    - cell_shift: Tensor of cell shift vectors for each edge, with shape (n_edges, 3), representing the periodic image shifts.

Returns:

  • unique_cell_vectors (torch.Tensor): Tensor of unique cell shift vectors with shape (n_unique_cells, 3).

  • cell_vector_indices (torch.Tensor): Tensor mapping each edge to the index of its corresponding unique cell vector, with shape (n_edges,).

  • cell_vector_map (dict): Dictionary mapping each cell vector tuple to its index in unique_cell_vectors.

Return type:

tuple
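The four steps can be sketched with plain lists and a dictionary. Placing the zero vector first and keeping the remaining vectors in order of first appearance are choices made for this example only; the ordering produced by the library may differ, and `unique_cells_with_indices` is a hypothetical name.

```python
def unique_cells_with_indices(cell_shift):
    """Return (unique_vectors, per_edge_indices, vector_to_index_map)."""
    unique, index_of = [], {}
    # Seed with the zero vector so on-site terms always have a cell index.
    for shift in [(0, 0, 0)] + [tuple(s) for s in cell_shift]:
        if shift not in index_of:
            index_of[shift] = len(unique)
            unique.append(shift)
    per_edge = [index_of[tuple(s)] for s in cell_shift]
    return unique, per_edge, index_of


cell_shift = [[1, 0, 0], [0, 0, 0], [1, 0, 0], [-1, 0, 0]]
unique_vectors, edge_to_cell, cell_map = unique_cells_with_indices(cell_shift)
```

The per-edge index array lets later code address a whole group of edges by cell, instead of comparing shift vectors edge by edge.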

build_edge_lookup_structures(data, inverse_edge_indices=None)[source]

Build efficient edge lookup structures for crystal graph operations.

This function constructs two data structures that enable fast edge lookup:
  1. A mapping from each atom to all edges where that atom is the source
  2. A mapping from each atom and cell shift combination to corresponding edges

These structures accelerate operations that require finding all edges connected to specific atoms, particularly when periodic boundary conditions are involved.

Parameters:
  • data (DataObject) – Crystal graph data containing:
      - edge_index: Tensor of shape [2, num_edges] with source and target indices
      - unique_cell_shift: Tensor of unique cell shift vectors
      - cell_shift_indices: Tensor mapping each edge to its cell shift index
      - cell_index_map: Dictionary mapping cell shift tuples to indices
      - z: Atomic numbers, used to determine number of atoms

  • inverse_edge_indices (torch.Tensor, optional) – Tensor mapping each edge to its inverse edge index (i→j maps to j→i). Required for building the target lookup.

Returns:

  • source_edge_indices (list of torch.Tensor): For each atom, contains indices of edges where that atom is the source.

  • target_edge_indices (list of list of torch.Tensor): For each atom and cell shift combination, contains indices of edges where that atom is the target in the specified cell shift.

Return type:

tuple
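The shape of the two lookup structures can be illustrated with nested lists. This sketch groups edges directly by their target atom and cell shift index and does not use inverse_edge_indices, so it shows the layout of the result rather than the library's construction; `build_lookups` is a hypothetical name.

```python
def build_lookups(edge_index, cell_shift_indices, num_atoms, num_cells):
    """Group edge ids by source atom, and by (target atom, cell shift index)."""
    by_source = [[] for _ in range(num_atoms)]
    by_target = [[[] for _ in range(num_cells)] for _ in range(num_atoms)]
    for edge_id, (src, dst) in enumerate(zip(*edge_index)):
        by_source[src].append(edge_id)
        by_target[dst][cell_shift_indices[edge_id]].append(edge_id)
    return by_source, by_target


edge_index = [[0, 0, 1],     # sources
              [1, 1, 0]]     # targets
cell_shift_indices = [0, 1, 0]
by_source, by_target = build_lookups(edge_index, cell_shift_indices,
                                     num_atoms=2, num_cells=2)
```

Building these lists once turns "all edges leaving atom i" into a single index operation instead of a scan over edge_index.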

create_orbital_validity_mask(atomic_numbers)[source]

Create a mask identifying valid atomic orbitals for each element type.

This function generates a binary mask tensor where each row corresponds to an element in the periodic table (indexed by atomic number), and each column represents an atomic orbital. A value of 1 indicates that the orbital is valid for that element type, while 0 indicates an invalid orbital.

Parameters:

atomic_numbers (torch.Tensor) – Tensor containing atomic numbers, used only for device and dtype information.

Returns:

Binary mask tensor of shape (99, nao_max) where 99 covers all possible elements in the periodic table and nao_max is the maximum number of atomic orbitals.

Return type:

torch.Tensor

build_interaction_masks(data)[source]

Build boolean masks for valid orbital interactions in Hamiltonian matrices.

This function generates masks that identify which elements in the Hamiltonian matrices correspond to valid orbital-orbital interactions, based on the atomic species involved. It creates separate masks for on-site (same atom) and off-site (different atoms) interactions.

Parameters:

data (DataObject) – Graph data containing:
    - edge_index: Indices of connected atom pairs
    - z: Atomic numbers for each atom

Returns:

Combined mask tensor with shape (n_atoms + n_edges, nao_max^2) where the first n_atoms rows are for on-site interactions and the remaining n_edges rows are for off-site interactions.

Return type:

torch.Tensor

build_column_wise_interaction_masks(data)[source]

Build interaction masks with an additional column dimension.

Similar to build_interaction_masks, but creates masks with an additional dimension for column-wise operations (e.g., for real and imaginary parts).

Parameters:

data (DataObject) – Graph data containing:
    - edge_index: Indices of connected atom pairs
    - z: Atomic numbers for each atom

Returns:

Combined mask tensor with shape (n_atoms + n_edges, 2, nao_max^2) where the additional dimension can be used for real and imaginary parts.

Return type:

torch.Tensor

build_spin_orbit_interaction_masks(data)[source]

Build interaction masks that account for spin-orbit coupling.

This function generates masks for systems with spin-orbit coupling, which requires handling interactions between different spin components. The resulting masks have double the orbital dimension to account for spin-up and spin-down components.

Parameters:

data (DataObject) – Graph data containing:
    - edge_index: Indices of connected atom pairs
    - z: Atomic numbers for each atom

Returns:

A tuple containing:
  • real_imag_masks (torch.Tensor): Mask for real and imaginary components with shape (n_atoms + n_edges, (2*nao_max)^2)

  • combined_masks (torch.Tensor): Full mask tensor with shape (2*(n_atoms + n_edges), (2*nao_max)^2)

Return type:

tuple

calculate_sparsity_ratio(data)[source]

Calculate the ratio between the total possible matrix elements and the effective matrix elements.

This function computes the sparsity of Hamiltonian matrices by dividing the total number of possible matrix elements by the number of effective matrix elements based on the atomic basis definitions.

Parameters:

data (object) – Data object containing atomic information and Hamiltonian matrices.

Returns:

A scalar tensor on the same device as data.z. The sparsity ratio is defined as the total number of matrix elements divided by the number of effective matrix elements. Returns inf if there are no effective elements.

Return type:

torch.Tensor

Notes

The calculation considers both on-site and off-site Hamiltonian elements if they are present in the data object.
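The definition of the ratio, including the inf case, fits in a few lines. The sketch takes the counts as plain numbers; how the effective count is derived from the basis definitions is not shown here, and the function and argument names are hypothetical.

```python
import math


def sparsity_ratio(num_blocks, nao_max, effective_elements):
    """Total padded matrix elements divided by the effective (non-padded) ones."""
    total_elements = num_blocks * nao_max * nao_max
    if effective_elements == 0:
        return math.inf
    return total_elements / effective_elements


# Two blocks padded to 3 x 3, of which only 6 elements correspond to real orbitals.
ratio = sparsity_ratio(num_blocks=2, nao_max=3, effective_elements=6)
```

A ratio of 1 means no padding at all; larger values indicate that more of each padded block is masked out.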

validate_elements_in_basis_def(data, raise_error=True)[source]

Validate that all elements in the input data exist in the basis_def dictionary.

Notes

This function is used to ensure that all elements in the molecular system have corresponding basis set definitions before performing calculations.

forward(data, graph_representation=None)[source]

Forward pass of the Hamiltonian prediction model.

This method constructs Hamiltonian matrices for crystal structures from graph representations, with support for various physical effects including spin-orbit coupling (SOC), magnetism, and long-range interactions. It can also calculate electronic band structures.

The method handles several different physical regimes:
  1. Non-magnetic systems (standard DFT-like Hamiltonians)
  2. Collinear spin-polarized systems (separate Hamiltonians for up/down spins)
  3. Non-collinear magnetic systems (full 2x2 spin blocks with SOC)
  4. Systems with spin-orbit coupling (complex Hamiltonians with real/imaginary parts)

Parameters:
  • data (DataObject) – Contains crystal structure information including:
      - edge_index: Connectivity information between atoms
      - z: Atomic numbers
      - pos: Atomic positions
      - cell: Unit cell vectors
      - Hon/Hoff: Reference on-site/off-site Hamiltonian matrices (if available)
      - Son/Soff: Reference on-site/off-site overlap matrices
      - node_counts: Number of atoms in each crystal
      - batch: Batch assignment for each atom
      - inv_edge_idx: Inverse edge indices

  • graph_representation (dict) – Output from the graph neural network containing:
      - node_attr: Node feature vectors from the GNN
      - edge_attr: Edge feature vectors from the GNN

Returns:

Dictionary containing predicted quantities, which may include:
  • hamiltonian: Predicted Hamiltonian matrix

  • overlap: Predicted overlap matrix (if ham_only=False)

  • band_energy: Electronic band energies (if calculate_band_energy=True)

  • wavefunction: Eigenvectors of the Hamiltonian (if calculate_band_energy=True)

  • band_gap: Band gap values (if calculate_band_energy=True)

  • Various additional matrices for debugging or analysis

Return type:

dict

Note

The exact contents of the returned dictionary depend on the model configuration parameters (soc_switch, spin_constrained, collinear_spin, etc.).
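For readers unfamiliar with how band energies follow from a Hamiltonian and an overlap matrix in a non-orthogonal basis, the sketch below solves the generalized eigenvalue problem H c = E S c for one pair of small matrices. It is a generic NumPy illustration based on a Cholesky factorization of S, not the routine used inside forward; the function name and the 2x2 test matrices are made up for the example.

```python
import numpy as np

def solve_generalized_eigenproblem(H, S):
    """Solve H c = E S c for Hermitian H and positive-definite S.

    Returns energies in ascending order and coefficient vectors (columns)
    normalized so that c^H S c = identity.
    """
    L = np.linalg.cholesky(S)             # S = L L^H, L lower triangular
    L_inv = np.linalg.inv(L)
    H_ortho = L_inv @ H @ L_inv.conj().T  # ordinary Hermitian problem
    energies, vecs = np.linalg.eigh(H_ortho)
    coeffs = L_inv.conj().T @ vecs        # back-transform: c = L^{-H} v
    return energies, coeffs

# Made-up 2x2 Hamiltonian and overlap matrices.
H = np.array([[-1.0, 0.2], [0.2, 0.5]])
S = np.array([[1.0, 0.1], [0.1, 1.0]])
energies, coeffs = solve_generalized_eigenproblem(H, S)
```

For a periodic crystal the same step would be applied at each k-point to the Fourier-summed H(k) and S(k).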

Main Model Implementation

PyTorch Lightning training module wiring representation and output networks.

Defines Model with optimizers, losses, metrics, and distributed training hooks.

class hamgnn.models.Model.Model(representation, output, losses, validation_metrics, lr=0.001, lr_decay=0.1, lr_patience=100, lr_monitor='training/total_loss', epsilon=1e-08, beta1=0.99, beta2=0.999, amsgrad=True, max_points_to_scatter=100000, post_processing=None)[source]

Bases: LightningModule

A PyTorch Lightning module for scientific machine learning models.

This class implements a modular architecture with representation and output components, handles training, validation, and testing with customizable losses and metrics, and supports gradient-based computations.

Parameters:
  • representation (nn.Module) – Neural network module that computes feature representations from input data

  • output (nn.Module) – Neural network module that transforms representations into predictions

  • losses (List[Dict]) – List of dictionaries defining loss functions, their targets, predictions, and weights

  • validation_metrics (List[Dict]) – List of dictionaries defining metrics to track during validation

  • lr (float, default=1e-3) – Initial learning rate for optimizer

  • lr_decay (float, default=0.1) – Factor by which learning rate is reduced on plateau

  • lr_patience (int, default=100) – Number of epochs with no improvement after which learning rate is reduced

  • lr_monitor (str, default="training/total_loss") – Metric to monitor for learning rate scheduling

  • epsilon (float, default=1e-8) – Small constant for numerical stability in optimizer

  • beta1 (float, default=0.99) – Exponential decay rate for first moment estimates in Adam optimizer

  • beta2 (float, default=0.999) – Exponential decay rate for second moment estimates in Adam optimizer

  • amsgrad (bool, default=True) – Whether to use AMSGrad variant of Adam optimizer

  • max_points_to_scatter (int, default=100000) – Maximum number of points to include in scatter plots

  • post_processing (callable, optional) – Function for calculating additional physical quantities that may require gradient backpropagation

calculate_loss(batch, predictions, mode)[source]

Calculate the total loss by summing weighted individual loss components.

Parameters:
  • batch (Dict[str, torch.Tensor]) – Dictionary containing input data and target values

  • predictions (Dict[str, torch.Tensor]) – Dictionary containing model predictions

  • mode (str) – Current mode ('training', 'validation', or 'test')

Returns:

Total weighted loss

Return type:

torch.Tensor
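The weighted sum can be sketched in a few lines. The example below uses plain Python lists instead of tensors and assumes that each entry of the losses list carries the keys "metric", "prediction", "target", and "loss_weight"; these key names, the helper names, and the toy data are hypothetical and only illustrate the "loss functions, their targets, predictions, and weights" structure described for the losses parameter.

```python
def mean_absolute_error(pred, target):
    """Simple MAE over two equal-length lists of floats."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

def total_weighted_loss(batch, predictions, losses):
    """Sum of loss_weight * metric(prediction, target) over all loss specs."""
    total = 0.0
    for spec in losses:
        pred = predictions[spec["prediction"]]
        target = batch[spec["target"]]
        total += spec["loss_weight"] * spec["metric"](pred, target)
    return total

# Hypothetical loss configuration with two weighted components.
loss_specs = [
    {"metric": mean_absolute_error, "prediction": "hamiltonian",
     "target": "hamiltonian", "loss_weight": 1.0},
    {"metric": mean_absolute_error, "prediction": "band_energy",
     "target": "band_energy", "loss_weight": 0.1},
]
toy_batch = {"hamiltonian": [1.5, 2.5], "band_energy": [0.0]}
toy_predictions = {"hamiltonian": [1.0, 2.0], "band_energy": [0.5]}

loss = total_weighted_loss(toy_batch, toy_predictions, loss_specs)
```

Keeping the weights in the configuration rather than in code makes it straightforward to rebalance components, for example to emphasize band energies late in training.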

training_step(batch, batch_idx)[source]

Perform a single training step.

Parameters:
  • batch (Dict[str, torch.Tensor]) – Dictionary containing training data

  • batch_idx (int) – Index of the current batch

Returns:

Training loss for this step

Return type:

torch.Tensor

validation_step(batch, batch_idx)[source]

Perform a single validation step.

Parameters:
  • batch (Dict[str, torch.Tensor]) – Dictionary containing validation data

  • batch_idx (int) – Index of the current batch

Returns:

Dictionary containing predictions and targets for logging

Return type:

Dict

validation_epoch_end(validation_step_outputs)[source]

Process and log validation results at the end of an epoch.

Parameters:

validation_step_outputs (List[Dict]) – List of outputs from all validation steps in the epoch

Return type:

None

test_step(batch, batch_idx)[source]

Perform a single test step.

Parameters:
  • batch (Dict[str, torch.Tensor]) – Dictionary containing test data

  • batch_idx (int) – Index of the current batch

Returns:

Dictionary containing predictions, targets, and processed values

Return type:

Dict

test_epoch_end(test_step_outputs)[source]

Process and log test results at the end of testing.

Parameters:

test_step_outputs (List[Dict]) – List of outputs from all test steps

Return type:

None

forward(batch)[source]

Forward pass through the model.

Parameters:

batch (Dict[str, torch.Tensor]) – Dictionary containing input data

Returns:

Dictionary containing model predictions

Return type:

Dict[str, torch.Tensor]

log_metrics(batch, predictions, mode)[source]

Log evaluation metrics for the current batch.

Parameters:
  • batch (Dict[str, torch.Tensor]) – Dictionary containing input data and target values

  • predictions (Dict[str, torch.Tensor]) – Dictionary containing model predictions

  • mode (str) – Current mode ('validation' or 'test')

Return type:

None

configure_optimizers()[source]

Configure optimizers and learning rate schedulers.

Returns:

Configuration dictionary for PyTorch Lightning

Return type:

Dict
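To show how the constructor parameters (lr, lr_decay, lr_patience, lr_monitor, epsilon, beta1, beta2, amsgrad) could map onto an optimizer configuration, here is a minimal sketch. It assumes torch.optim.Adam combined with ReduceLROnPlateau, which matches the parameter descriptions above but is not confirmed as the exact implementation; the function name is hypothetical. The returned dictionary uses the "optimizer" / "lr_scheduler" layout that PyTorch Lightning accepts from configure_optimizers.

```python
import torch

def configure_optimizers_sketch(model, lr=1e-3, lr_decay=0.1, lr_patience=100,
                                lr_monitor="training/total_loss", epsilon=1e-8,
                                beta1=0.99, beta2=0.999, amsgrad=True):
    """Build an Adam optimizer and a reduce-on-plateau scheduler (assumed setup)."""
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        betas=(beta1, beta2),
        eps=epsilon,
        amsgrad=amsgrad,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",            # the monitored loss should decrease
        factor=lr_decay,       # multiply the learning rate by this on plateau
        patience=lr_patience,  # epochs without improvement before reducing
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": lr_monitor},
    }

# A tiny stand-in network; any nn.Module with parameters works here.
config = configure_optimizers_sketch(torch.nn.Linear(4, 2))
```

The "monitor" entry tells Lightning which logged metric drives the plateau scheduler, so that metric name must actually be logged during training.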