# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only
"""Hamiltonian and band-structure output layers, tensor products, and k-space utilities.
Implements equivariant output heads (e.g. :class:`HamLayer`), structure handling
via pymatgen, and scattering-based reductions for graph batches.
"""
import copy
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import opt_einsum as oe
import torch
from e3nn import o3
from e3nn.util.jit import compile_mode
from pymatgen.core.periodic_table import Element
from pymatgen.core.structure import Structure
from pymatgen.symmetry.kpath import KPathSeek
from torch import nn
from torch_geometric.nn import global_mean_pool
from torch_scatter import scatter
from ..nn.interaction_blocks import ResidualBlock
from ..nn.tensor_decomposition import E3TensorDecomposition
from ..physics.Clebsch_Gordan_coefficients import ClebschGordanCoefficients
from ..physics.kpoints import kpoints_generator
from ..utils.constants import au2ang
from ..utils.math_utils import (
blockwise_2x2_concat,
extract_elements_above_threshold,
upgrade_tensor_precision
)
@compile_mode("script")
class HamLayer(nn.Module):
    """Output head made of a residual block followed by an ``o3.Linear`` projection.

    Args:
        irreps_in: Irreps of the incoming features. Also used as the input irreps of
            the final linear projection.
        feature_irreps_hidden: Hidden irreps forwarded to :class:`ResidualBlock`.
        irreps_out: Irreps produced by the final linear projection.
        nonlinearity_type: Nonlinearity selector forwarded to :class:`ResidualBlock`.
        resnet: Forwarded unchanged to :class:`ResidualBlock`.
    """

    def __init__(self, irreps_in, feature_irreps_hidden, irreps_out, nonlinearity_type: str = "gate", resnet: bool = True):
        super().__init__()
        # Feature refinement stage.
        self.residual_block = ResidualBlock(irreps_in=irreps_in,
                                            feature_irreps_hidden=feature_irreps_hidden,
                                            nonlinearity_type=nonlinearity_type,
                                            resnet=resnet)
        # Projection onto the requested output irreps.
        # NOTE(review): built with ``irreps_in`` as its input, so this assumes the
        # residual block returns features with the same irreps it received - confirm.
        self.linear_transform = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)

    def forward(self, x):
        """Apply the residual block, then the linear projection, and return the result."""
        x = self.residual_block(x)
        x = self.linear_transform(x)
        return x
class HamGNNPlusPlusOut(nn.Module):
"""
Neural network module for computing Hamiltonian matrices using Graph Neural Networks.
This class implements a GNN architecture for quantum chemistry applications, supporting
different Hamiltonian types (openmx, siesta, abacus) with various configuration options
for physics-informed neural network training.
Args:
irreps_in_node (Union[int, str, o3.Irreps]): Input irreps for node features.
irreps_in_edge (Union[int, str, o3.Irreps]): Input irreps for edge features.
nao_max (int): Maximum number of atomic orbitals.
return_forces (bool): Whether to compute forces during forward pass.
create_graph (bool): Whether to create computational graph for backpropagation.
ham_type (str): Type of Hamiltonian ('openmx', 'siesta', 'abacus', or 'pasp').
ham_only (bool): Whether to only compute the Hamiltonian matrix.
symmetrize (bool): Whether to symmetrize the Hamiltonian matrix.
include_triplet (bool): Whether to include triplet interactions.
calculate_band_energy (bool): Whether to calculate band energies.
num_k (int): Number of k-points for band structure calculations.
k_path (Union[list, np.ndarray, tuple]): Path in k-space for band structure.
band_num_control (dict or int): Control for band numbers.
soc_switch (bool): Whether to include spin-orbit coupling.
nonlinearity_type (str): Type of nonlinearity for neural network layers.
export_reciprocal_values (bool): Whether to export reciprocal space values.
add_H0 (bool): Whether to add the initial Hamiltonian term H0.
soc_basis (str): Basis for spin-orbit coupling ('so3' or 'su2').
spin_constrained (bool): Whether to apply spin constraints.
use_learned_weight (bool): Whether to use learned weights.
minMagneticMoment (float): Minimum magnetic moment threshold.
collinear_spin (bool): Whether to consider collinear spin.
zero_point_shift (bool): Whether to apply zero-point energy shift.
add_H_nonsoc (bool): Whether to add non-SOC Hamiltonian.
get_nonzero_mask_tensor (bool): Whether to get nonzero mask tensor.
calculate_sparsity (bool): Whether to calculate sparsity.
"""
def __init__(self,
             irreps_in_node: Union[int, str, o3.Irreps] = None,
             irreps_in_edge: Union[int, str, o3.Irreps] = None,
             nao_max: int = 14,
             return_forces: bool = False,
             create_graph: bool = False,
             ham_type: str = 'openmx',
             ham_only: bool = False,
             symmetrize: bool = True,
             include_triplet: bool = False,
             calculate_band_energy: bool = False,
             num_k: int = 8,
             k_path: Union[list, np.ndarray, tuple] = None,
             band_num_control: dict = None,
             soc_switch: bool = True,
             nonlinearity_type: str = 'gate',
             export_reciprocal_values: bool = False,
             add_H0: bool = False,
             soc_basis: str = 'so3',
             spin_constrained: bool = False,
             use_learned_weight: bool = True,
             minMagneticMoment: float = 0.5,
             collinear_spin: bool = False,
             zero_point_shift: bool = False,
             add_H_nonsoc: bool = False,
             get_nonzero_mask_tensor: bool = False,
             calculate_sparsity: bool = True
             ):
    """
    Record the configuration, derive the basis/irreps tables and build the output heads.

    ``calculate_sparsity`` is stored unchanged as ``self.calculate_sparsity``; the other
    arguments are stored on the instance and/or used below to decide which heads exist.

    Raises:
        NotImplementedError: If ``soc_switch`` is set and the effective SOC basis is
            neither 'su2' nor 'so3'. The basis initialisers called here may raise it
            as well for unsupported ``ham_type`` / ``nao_max`` values.
    """
    super().__init__()

    def build_head_pair(irreps_out):
        """Create the node-level head first, then the edge-level head, for one target irreps."""
        node_head = self._create_hamiltonian_layer(irreps_in=irreps_in_node, irreps_out=irreps_out)
        edge_head = self._create_hamiltonian_layer(irreps_in=irreps_in_edge, irreps_out=irreps_out)
        return node_head, edge_head

    # --- differentiation options -------------------------------------------------
    self.derivative, self.create_graph = return_forces, create_graph

    # --- general model configuration ---------------------------------------------
    self.nao_max = nao_max
    self.ham_type = ham_type.lower()
    self.ham_only = ham_only
    self.symmetrize = symmetrize
    self.include_triplet = include_triplet
    self.soc_switch = soc_switch
    self.nonlinearity_type = nonlinearity_type
    self.export_reciprocal_values = export_reciprocal_values
    self.add_H0 = add_H0
    self.calculate_sparsity = calculate_sparsity

    # --- spin handling -----------------------------------------------------------
    self.spin_constrained = spin_constrained
    self.use_learned_weight = use_learned_weight
    self.min_magnetic_moment = minMagneticMoment
    self.collinear_spin = collinear_spin
    # With SOC enabled, every Hamiltonian type other than 'openmx' is forced onto
    # the 'su2' basis regardless of what was requested.
    requested_basis = soc_basis.lower()
    self.soc_basis = 'su2' if (soc_switch and self.ham_type != 'openmx') else requested_basis

    # --- band-structure options --------------------------------------------------
    self.calculate_band_energy = calculate_band_energy
    self.num_k = num_k
    self.k_path = k_path

    # --- remaining physics switches ----------------------------------------------
    self.add_quartic = False  # hard-wired off, so the K heads below are never built
    self.zero_point_shift = zero_point_shift
    self.add_H_nonsoc = add_H_nonsoc
    self.get_nonzero_mask_tensor = get_nonzero_mask_tensor

    # --- derived tables (order matters: irreps need row/col from the basis step) --
    self._configure_band_num_control(band_num_control)
    self._initialize_basis_information()
    self._initialize_irreducible_representations()

    # Clebsch-Gordan coefficients up to the largest l present in the Hamiltonian irreps.
    self.cg_calculator = ClebschGordanCoefficients(max_l=self.hamiltonian_irreps.lmax)

    # --- Hamiltonian heads --------------------------------------------------------
    self.onsite_hamiltonian_network, self.offsite_hamiltonian_network = build_head_pair(
        self.hamiltonian_irreps
    )

    # --- spin-orbit coupling heads ------------------------------------------------
    if soc_switch:
        active_basis = self.soc_basis
        if active_basis == 'so3':
            # One scalar (0e) channel per entry of the nao_max x nao_max matrix.
            scalar_irreps = (self.nao_max**2*o3.Irreps("0e")).simplify()
            self.onsite_ksi_network, self.offsite_ksi_network = build_head_pair(scalar_irreps)
        elif active_basis == 'su2':
            # The su2 basis replaces the Hamiltonian heads created above.
            doubled_su2_irreps = 2*self.hamiltonian_irreps_su2
            self.onsite_hamiltonian_network, self.offsite_hamiltonian_network = build_head_pair(
                doubled_su2_irreps
            )
        else:
            raise NotImplementedError(f"SOC basis '{soc_basis}' not supported!")

    # --- spin-constrained heads ---------------------------------------------------
    if self.spin_constrained:
        # J heads
        self.onsite_J_network, self.offsite_J_network = build_head_pair(self.J_irreps)
        # K heads (only when the quartic term is enabled)
        if self.add_quartic:
            self.onsite_K_network, self.offsite_K_network = build_head_pair(self.K_irreps)
        # Learned weight heads
        if self.use_learned_weight:
            self.onsite_weight_network, self.offsite_weight_network = build_head_pair(
                self.hamiltonian_irreps
            )

    # --- overlap heads (skipped in Hamiltonian-only mode) ---------------------------
    if not self.ham_only:
        self.onsite_overlap_network, self.offsite_overlap_network = build_head_pair(
            self.hamiltonian_irreps
        )
def _initialize_irreducible_representations(self):
    """
    Build the irreps tables for the Hamiltonian and the optional SOC / spin terms.

    Always sets ``hamiltonian_irreps`` and ``hamiltonian_irreps_dimensions``. With SOC in
    the 'su2' basis it also sets ``hamiltonian_decomposition`` and
    ``hamiltonian_irreps_su2``. In spin-constrained mode it also sets ``J_irreps``,
    ``K_irreps``, their dimension tensors and ``num_blocks``.
    """
    # (l_row, l_col) of every shell pair, rows outermost.
    shell_pairs = [(row_ir.l, col_ir.l) for _, row_ir in self.row for _, col_ir in self.col]

    # Main Hamiltonian irreps: each pair contributes L = |l1-l2| .. l1+l2 with
    # parity (-1)**(l1+l2).
    ham_irreps = o3.Irreps()
    for l_row, l_col in shell_pairs:
        for total_l in range(abs(l_row - l_col), l_row + l_col + 1):
            ham_irreps += o3.Irrep(total_l, (-1)**(l_row + l_col))
    self.hamiltonian_irreps = ham_irreps
    self.hamiltonian_irreps_dimensions = torch.LongTensor([entry.dim for entry in ham_irreps])

    # SU(2) basis irreps (spin-orbit coupling only).
    if self.soc_switch and (self.soc_basis == 'su2'):
        self.hamiltonian_decomposition = E3TensorDecomposition(
            None,
            list(shell_pairs),
            default_dtype_torch=torch.get_default_dtype(),
            nao_max=self.nao_max,
            spinful=True
        )
        self.hamiltonian_irreps_su2 = self.hamiltonian_decomposition.required_irreps_out

    if not self.spin_constrained:
        return

    # Spin-constrained mode: J and K irreps, one group per shell pair.
    j_irreps = o3.Irreps()
    k_irreps = o3.Irreps()
    block_count = 0
    for _ in shell_pairs:
        block_count += 1
        if self.soc_switch:
            for rank in range(0, 3):
                j_irreps += o3.Irrep(rank, 1)  # even parity
            k_irreps += o3.Irrep(0, 1)  # even scalar
        else:
            j_irreps += o3.Irrep(0, 1)  # even scalar
    self.J_irreps = j_irreps
    self.K_irreps = k_irreps
    self.num_blocks = block_count
    self.J_irreps_dimensions = torch.LongTensor([entry.dim for entry in j_irreps])
    self.K_irreps_dimensions = torch.LongTensor([entry.dim for entry in k_irreps])
def _initialize_basis_information(self):
"""
Initialize basis information based on the selected Hamiltonian type.
Configures basis sets for different Hamiltonian types:
- openmx: Basis sets for OpenMX DFT code
- siesta: Basis sets for SIESTA DFT code
- abacus: Basis sets for ABACUS DFT code
- pasp: Simplified basis set
"""
if self.ham_type == 'openmx':
self._initialize_openmx_basis()
elif self.ham_type == 'siesta':
self._initialize_siesta_basis()
elif self.ham_type == 'abacus':
self._initialize_abacus_basis()
elif self.ham_type == 'pasp':
self.row = self.col = o3.Irreps("1x1o")
else:
raise NotImplementedError(f"Hamiltonian type '{self.ham_type}' is not supported.")
def _initialize_openmx_basis(self) -> None:
    """
    Set the basis tables used for the 'openmx' Hamiltonian type.

    Always sets ``self.num_valence`` (dict keyed by atomic number). Depending on
    ``self.nao_max`` (14, 13, 19 or 26) it also sets:

    - ``self.index_change``: LongTensor holding a permutation of ``range(nao_max)``.
    - ``self.row`` / ``self.col``: one shared ``o3.Irreps`` listing the orbital shells.
    - ``self.basis_def``: dict mapping atomic number to a list of orbital indices
      (presumably the slots of the ``nao_max``-wide basis populated for that element).

    Raises:
        NotImplementedError: If ``self.nao_max`` is not one of 14, 13, 19 or 26.
    """
    # Per-element integer table keyed by atomic number; presumably the number of
    # valence electrons of the pseudopotential in use - TODO confirm.
    self.num_valence = {Element['H'].Z: 1, Element['He'].Z: 2, Element['Li'].Z: 3, Element['Be'].Z: 2, Element['B'].Z: 3,
                        Element['C'].Z: 4, Element['N'].Z: 5, Element['O'].Z: 6, Element['F'].Z: 7, Element['Ne'].Z: 8,
                        Element['Na'].Z: 9, Element['Mg'].Z: 8, Element['Al'].Z: 3, Element['Si'].Z: 4, Element['P'].Z: 5,
                        Element['S'].Z: 6, Element['Cl'].Z: 7, Element['Ar'].Z: 8, Element['K'].Z: 9, Element['Ca'].Z: 10,
                        Element['Sc'].Z: 11, Element['Ti'].Z: 12, Element['V'].Z: 13, Element['Cr'].Z: 14, Element['Mn'].Z: 15,
                        Element['Fe'].Z: 16, Element['Co'].Z: 17, Element['Ni'].Z: 18, Element['Cu'].Z: 19, Element['Zn'].Z: 20,
                        Element['Ga'].Z: 13, Element['Ge'].Z: 4, Element['As'].Z: 15, Element['Se'].Z: 6, Element['Br'].Z: 7,
                        Element['Kr'].Z: 8, Element['Rb'].Z: 9, Element['Sr'].Z: 10, Element['Y'].Z: 11, Element['Zr'].Z: 12,
                        Element['Nb'].Z: 13, Element['Mo'].Z: 14, Element['Tc'].Z: 15, Element['Ru'].Z: 14, Element['Rh'].Z: 15,
                        Element['Pd'].Z: 16, Element['Ag'].Z: 17, Element['Cd'].Z: 12, Element['In'].Z: 13, Element['Sn'].Z: 14,
                        Element['Sb'].Z: 15, Element['Te'].Z: 16, Element['I'].Z: 7, Element['Xe'].Z: 8, Element['Cs'].Z: 9,
                        Element['Ba'].Z: 10, Element['La'].Z: 11, Element['Ce'].Z: 12, Element['Pr'].Z: 13, Element['Nd'].Z: 14,
                        Element['Pm'].Z: 15, Element['Sm'].Z: 16, Element['Dy'].Z: 20, Element['Ho'].Z: 21, Element['Lu'].Z: 11,
                        Element['Hf'].Z: 12, Element['Ta'].Z: 13, Element['W'].Z: 12, Element['Re'].Z: 15, Element['Os'].Z: 14,
                        Element['Ir'].Z: 15, Element['Pt'].Z: 16, Element['Au'].Z: 17, Element['Hg'].Z: 18, Element['Tl'].Z: 19,
                        Element['Pb'].Z: 14, Element['Bi'].Z: 15
                        }
    if self.nao_max == 14:
        # Shell layout s,s,s,p,p,d (3 + 6 + 5 = 14 orbitals).
        self.index_change = torch.LongTensor([0,1,2,5,3,4,8,6,7,11,13,9,12,10])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e")
        self.basis_def = {  1:[0,1,3,4,5], # H
                            2:[0,1,3,4,5], # He
                            3:[0,1,2,3,4,5,6,7,8], # Li
                            4:[0,1,3,4,5,6,7,8], # Be
                            5:[0,1,3,4,5,6,7,8,9,10,11,12,13], # B
                            6:[0,1,3,4,5,6,7,8,9,10,11,12,13], # C
                            7:[0,1,3,4,5,6,7,8,9,10,11,12,13], # N
                            8:[0,1,3,4,5,6,7,8,9,10,11,12,13], # O
                            9:[0,1,3,4,5,6,7,8,9,10,11,12,13], # F
                            10:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Ne
                            11:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Na
                            12:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Mg
                            13:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Al
                            14:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Si
                            15:[0,1,3,4,5,6,7,8,9,10,11,12,13], # P
                            16:[0,1,3,4,5,6,7,8,9,10,11,12,13], # S
                            17:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Cl
                            18:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Ar
                            19:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # K
                            20:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Ca
                            35:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Br
                            Element['V'].Z: [0,1,2,3,4,5,6,7,8,9,10,11,12,13], # V
                            Element['Mn'].Z: [0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Mn
                         }
    elif self.nao_max == 13:
        # Shell layout s,s,p,p,d (2 + 6 + 5 = 13 orbitals); only H, B, C, N, O are listed.
        self.basis_def = {  1:[0,1,2,3,4], # H
                            5:[0,1,2,3,4,5,6,7,8,9,10,11,12], # B
                            6:[0,1,2,3,4,5,6,7,8,9,10,11,12], # C
                            7:[0,1,2,3,4,5,6,7,8,9,10,11,12], # N
                            8:[0,1,2,3,4,5,6,7,8,9,10,11,12] # O
                         }
        self.index_change = torch.LongTensor([0,1,4,2,3,7,5,6,10,12,8,11,9])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x1o+1x1o+1x2e")
    elif self.nao_max == 19:
        # Shell layout s,s,s,p,p,d,d (3 + 6 + 10 = 19 orbitals).
        self.index_change = torch.LongTensor([0,1,2,5,3,4,8,6,7,11,13,9,12,10,16,18,14,17,15])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e")
        self.basis_def = {  1:[0,1,3,4,5], # H
                            2:[0,1,3,4,5], # He
                            3:[0,1,2,3,4,5,6,7,8], # Li
                            4:[0,1,3,4,5,6,7,8], # Be
                            5:[0,1,3,4,5,6,7,8,9,10,11,12,13], # B
                            6:[0,1,3,4,5,6,7,8,9,10,11,12,13], # C
                            7:[0,1,3,4,5,6,7,8,9,10,11,12,13], # N
                            8:[0,1,3,4,5,6,7,8,9,10,11,12,13], # O
                            9:[0,1,3,4,5,6,7,8,9,10,11,12,13], # F
                            10:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Ne
                            11:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Na
                            12:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Mg
                            13:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Al
                            14:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Si
                            15:[0,1,3,4,5,6,7,8,9,10,11,12,13], # P
                            16:[0,1,3,4,5,6,7,8,9,10,11,12,13], # S
                            17:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Cl
                            18:[0,1,3,4,5,6,7,8,9,10,11,12,13], # Ar
                            19:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # K
                            20:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Ca
                            25:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Mn
                            42:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Mo
                            83:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Bi
                            34:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Se
                            24:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Cr
                            53:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # I
                            28:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Ni
                            35:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Br
                            26:[0,1,2,3,4,5,6,7,8,9,10,11,12,13], # Fe
                            77:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Ir
                            52:[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Te
                            Element['V'].Z: [0,1,2,3,4,5,6,7,8,9,10,11,12,13], # V
                            Element['Sb'].Z: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18], # Sb
                         }
    elif self.nao_max == 26:
        # Shell layout s,s,s,p,p,d,d,f (3 + 6 + 10 + 7 = 26 orbitals).
        self.index_change = torch.LongTensor([0,1,2,5,3,4,8,6,7,11,13,9,12,10,16,18,14,17,15,22,23,21,24,20,25,19])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e+1x3o")
        # The lambda's default arguments name the index range of each shell; it is
        # called immediately so ``basis_def`` ends up being the plain dict it returns.
        self.basis_def = (lambda s1=[0],s2=[1],s3=[2],p1=[3,4,5],p2=[6,7,8],d1=[9,10,11,12,13],d2=[14,15,16,17,18],f1=[19,20,21,22,23,24,25]: {
            Element['H'].Z : s1+s2+p1, # H6.0-s2p1
            Element['He'].Z : s1+s2+p1, # He8.0-s2p1
            Element['Li'].Z : s1+s2+s3+p1+p2, # Li8.0-s3p2
            Element['Be'].Z : s1+s2+p1+p2, # Be7.0-s2p2
            Element['B'].Z : s1+s2+p1+p2+d1, # B7.0-s2p2d1
            Element['C'].Z : s1+s2+p1+p2+d1, # C6.0-s2p2d1
            Element['N'].Z : s1+s2+p1+p2+d1, # N6.0-s2p2d1
            Element['O'].Z : s1+s2+p1+p2+d1, # O6.0-s2p2d1
            Element['F'].Z : s1+s2+p1+p2+d1, # F6.0-s2p2d1
            Element['Ne'].Z: s1+s2+p1+p2+d1, # Ne9.0-s2p2d1
            Element['Na'].Z: s1+s2+s3+p1+p2+d1, # Na9.0-s3p2d1
            Element['Mg'].Z: s1+s2+s3+p1+p2+d1, # Mg9.0-s3p2d1
            Element['Al'].Z: s1+s2+p1+p2+d1, # Al7.0-s2p2d1
            Element['Si'].Z: s1+s2+p1+p2+d1, # Si7.0-s2p2d1
            Element['P'].Z: s1+s2+p1+p2+d1, # P7.0-s2p2d1
            Element['S'].Z: s1+s2+p1+p2+d1, # S7.0-s2p2d1
            Element['Cl'].Z: s1+s2+p1+p2+d1, # Cl7.0-s2p2d1
            Element['Ar'].Z: s1+s2+p1+p2+d1, # Ar9.0-s2p2d1
            Element['K'].Z: s1+s2+s3+p1+p2+d1, # K10.0-s3p2d1
            Element['Ca'].Z: s1+s2+s3+p1+p2+d1, # Ca9.0-s3p2d1
            Element['Sc'].Z: s1+s2+s3+p1+p2+d1, # Sc9.0-s3p2d1
            Element['Ti'].Z: s1+s2+s3+p1+p2+d1, # Ti7.0-s3p2d1
            Element['V'].Z: s1+s2+s3+p1+p2+d1, # V6.0-s3p2d1
            Element['Cr'].Z: s1+s2+s3+p1+p2+d1, # Cr6.0-s3p2d1
            Element['Mn'].Z: s1+s2+s3+p1+p2+d1, # Mn6.0-s3p2d1
            Element['Fe'].Z: s1+s2+s3+p1+p2+d1, # Fe5.5H-s3p2d1
            Element['Co'].Z: s1+s2+s3+p1+p2+d1, # Co6.0H-s3p2d1
            Element['Ni'].Z: s1+s2+s3+p1+p2+d1, # Ni6.0H-s3p2d1
            Element['Cu'].Z: s1+s2+s3+p1+p2+d1, # Cu6.0H-s3p2d1
            Element['Zn'].Z: s1+s2+s3+p1+p2+d1, # Zn6.0H-s3p2d1
            Element['Ga'].Z: s1+s2+s3+p1+p2+d1+d2, # Ga7.0-s3p2d2
            Element['Ge'].Z: s1+s2+s3+p1+p2+d1+d2, # Ge7.0-s3p2d2
            Element['As'].Z: s1+s2+s3+p1+p2+d1+d2, # As7.0-s3p2d2
            Element['Se'].Z: s1+s2+s3+p1+p2+d1+d2, # Se7.0-s3p2d2
            Element['Br'].Z: s1+s2+s3+p1+p2+d1+d2, # Br7.0-s3p2d2
            Element['Kr'].Z: s1+s2+s3+p1+p2+d1+d2, # Kr10.0-s3p2d2
            Element['Rb'].Z: s1+s2+s3+p1+p2+d1+d2, # Rb11.0-s3p2d2
            Element['Sr'].Z: s1+s2+s3+p1+p2+d1+d2, # Sr10.0-s3p2d2
            Element['Y'].Z: s1+s2+s3+p1+p2+d1+d2, # Y10.0-s3p2d2
            Element['Zr'].Z: s1+s2+s3+p1+p2+d1+d2, # Zr7.0-s3p2d2
            Element['Nb'].Z: s1+s2+s3+p1+p2+d1+d2, # Nb7.0-s3p2d2
            Element['Mo'].Z: s1+s2+s3+p1+p2+d1+d2, # Mo7.0-s3p2d2
            Element['Tc'].Z: s1+s2+s3+p1+p2+d1+d2, # Tc7.0-s3p2d2
            Element['Ru'].Z: s1+s2+s3+p1+p2+d1+d2, # Ru7.0-s3p2d2
            Element['Rh'].Z: s1+s2+s3+p1+p2+d1+d2, # Rh7.0-s3p2d2
            Element['Pd'].Z: s1+s2+s3+p1+p2+d1+d2, # Pd7.0-s3p2d2
            Element['Ag'].Z: s1+s2+s3+p1+p2+d1+d2, # Ag7.0-s3p2d2
            Element['Cd'].Z: s1+s2+s3+p1+p2+d1+d2, # Cd7.0-s3p2d2
            Element['In'].Z: s1+s2+s3+p1+p2+d1+d2, # In7.0-s3p2d2
            Element['Sn'].Z: s1+s2+s3+p1+p2+d1+d2, # Sn7.0-s3p2d2
            Element['Sb'].Z: s1+s2+s3+p1+p2+d1+d2, # Sb7.0-s3p2d2
            Element['Te'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Te7.0-s3p2d2f1
            Element['I'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # I7.0-s3p2d2f1
            Element['Xe'].Z: s1+s2+s3+p1+p2+d1+d2, # Xe11.0-s3p2d2
            Element['Cs'].Z: s1+s2+s3+p1+p2+d1+d2, # Cs12.0-s3p2d2
            Element['Ba'].Z: s1+s2+s3+p1+p2+d1+d2, # Ba10.0-s3p2d2
            Element['La'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # La8.0-s3p2d2f1
            Element['Ce'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Ce8.0-s3p2d2f1
            Element['Pr'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Pr8.0-s3p2d2f1
            Element['Nd'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Nd8.0-s3p2d2f1
            Element['Pm'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Pm8.0-s3p2d2f1
            Element['Sm'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Sm8.0-s3p2d2f1
            Element['Dy'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Dy8.0-s3p2d2f1
            Element['Ho'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Ho8.0-s3p2d2f1
            Element['Lu'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Lu8.0-s3p2d2f1
            Element['Hf'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Hf9.0-s3p2d2f1
            Element['Ta'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Ta7.0-s3p2d2f1
            Element['W'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # W7.0-s3p2d2f1
            Element['Re'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Re7.0-s3p2d2f1
            Element['Os'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Os7.0-s3p2d2f1
            Element['Ir'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Ir7.0-s3p2d2f1
            Element['Pt'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Pt7.0-s3p2d2f1
            Element['Au'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Au7.0-s3p2d2f1
            Element['Hg'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Hg8.0-s3p2d2f1
            Element['Tl'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Tl8.0-s3p2d2f1
            Element['Pb'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Pb8.0-s3p2d2f1
            Element['Bi'].Z: s1+s2+s3+p1+p2+d1+d2+f1, # Bi8.0-s3p2d2f1
        })()
    else:
        raise NotImplementedError(f"NAO max '{self.nao_max}' not supported for 'openmx'.")
def _initialize_siesta_basis(self) -> None:
    """
    Set the basis tables used for the 'siesta' Hamiltonian type.

    Always sets ``self.num_valence`` (dict keyed by atomic number). Depending on
    ``self.nao_max`` (13 or 19) it also sets:

    - ``self.index_change``: ``None`` for this Hamiltonian type.
    - ``self.row`` / ``self.col``: one shared ``o3.Irreps`` listing the orbital shells.
    - ``self.minus_index``: LongTensor of orbital indices.
    - ``self.basis_def``: dict mapping atomic number to a list of orbital indices
      (presumably the slots of the ``nao_max``-wide basis populated for that element).

    Raises:
        NotImplementedError: If ``self.nao_max`` is neither 13 nor 19.
    """
    # Per-element integer table keyed by atomic number; presumably the number of
    # valence electrons - TODO confirm.
    self.num_valence = {
        1:1,2:2,
        3:1,4:2,5:3,6:4,7:5,8:6,9:7,10:8,
        11:1,12:2,13:3,14:4,15:5,16:6,17:7,18:8,
        19:1,20:2,22:12,31:3,33:5,72:4
    }
    if self.nao_max == 13:
        # Shell layout s,s,p,p,d (2 + 6 + 5 = 13 orbitals).
        self.index_change = None
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x1o+1x1o+1x2e")
        # NOTE(review): presumably the orbitals whose sign is flipped relative to
        # the internal convention - confirm where ``minus_index`` is consumed.
        self.minus_index = torch.LongTensor([2,4,5,7,9,11]) # this list should follow the order in siesta. See spher_harm.f
        # The lambda's default arguments name the index range of each shell; it is
        # called immediately so ``basis_def`` ends up being the plain dict it returns.
        self.basis_def = (lambda s1=[0],s2=[1],p1=[2,3,4],p2=[5,6,7],d1=[8,9,10,11,12]: {
            1 : s1+s2+p1, # H
            2 : s1+s2+p1, # He
            3 : s1+s2+p1, # Li
            4 : s1+s2+p1, # Be
            5 : s1+s2+p1+p2+d1, # B
            6 : s1+s2+p1+p2+d1, # C
            7 : s1+s2+p1+p2+d1, # N
            8 : s1+s2+p1+p2+d1, # O
            9 : s1+s2+p1+p2+d1, # F
            10: s1+s2+p1+p2+d1, # Ne
            11: s1+s2+p1, # Na
            12: s1+s2+p1, # Mg
            13: s1+s2+p1+p2+d1, # Al
            14: s1+s2+p1+p2+d1, # Si
            15: s1+s2+p1+p2+d1, # P
            16: s1+s2+p1+p2+d1, # S
            17: s1+s2+p1+p2+d1, # Cl
            18: s1+s2+p1+p2+d1, # Ar
            19: s1+s2+p1, # K
            20: s1+s2+p1, # Ca
            31: s1+s2+p1+p2+d1, # Ga
            33: s1+s2+p1+p2+d1, # As
        })()
    elif self.nao_max == 19:
        # Shell layout s,s,s,p,p,d,d (3 + 6 + 10 = 19 orbitals).
        self.index_change = None
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e")
        self.minus_index = torch.LongTensor([3,5,6,8,10,12,15,17]) # this list should follow the order in siesta. See spher_harm.f
        self.basis_def = (lambda s1=[0],s2=[1],s3=[2],p1=[3,4,5],p2=[6,7,8],d1=[9,10,11,12,13],d2=[14,15,16,17,18]: {
            1 : s1+s2+p1, # H
            2 : s1+s2+p1, # He
            3 : s1+s2+p1, # Li
            4 : s1+s2+p1, # Be
            5 : s1+s2+p1+p2+d1, # B
            6 : s1+s2+p1+p2+d1, # C
            7 : s1+s2+p1+p2+d1, # N
            8 : s1+s2+p1+p2+d1, # O
            9 : s1+s2+p1+p2+d1, # F
            10: s1+s2+p1+p2+d1, # Ne
            11: s1+s2+p1, # Na
            12: s1+s2+p1, # Mg
            13: s1+s2+p1+p2+d1, # Al
            14: s1+s2+p1+p2+d1, # Si
            15: s1+s2+p1+p2+d1, # P
            16: s1+s2+p1+p2+d1, # S
            17: s1+s2+p1+p2+d1, # Cl
            18: s1+s2+p1+p2+d1, # Ar
            19: s1+s2+p1, # K
            20: s1+s2+p1, # Ca
            22: s1+s2+s3+p1+p2+d1+d2, # Ti, created by Qin.
        })()
    else:
        raise NotImplementedError(f"NAO max '{self.nao_max}' not supported for 'siesta'.")
def _initialize_abacus_basis(self) -> None:
    """
    Set the basis tables used for the 'abacus' Hamiltonian type.

    Always sets ``self.num_valence`` (dict keyed by atomic number). Depending on
    ``self.nao_max`` (13, 27 or 40) it also sets:

    - ``self.index_change``: LongTensor holding a permutation of ``range(nao_max)``.
    - ``self.row`` / ``self.col``: one shared ``o3.Irreps`` listing the orbital shells.
    - ``self.minus_index``: LongTensor of orbital indices.
    - ``self.basis_def``: dict mapping atomic number to orbital indices. The values are
      integer ``np.ndarray`` objects for ``nao_max == 13`` and plain lists otherwise.

    Raises:
        NotImplementedError: If ``self.nao_max`` is not one of 13, 27 or 40.
    """
    # Per-element integer table keyed by atomic number; presumably the number of
    # valence electrons of the pseudopotential in use - TODO confirm.
    self.num_valence = {1: 1, 2: 2,
                        3: 3, 4: 4,
                        5: 3, 6: 4,
                        7: 5, 8: 6,
                        9: 7, 10: 8,
                        11: 9, 12: 10,
                        13: 11, 14: 4,
                        15: 5, 16: 6,
                        17: 7, 18: 8,
                        19: 9, 20: 10,
                        21: 11, 22: 12,
                        23: 13, 24: 14,
                        25: 15, 26: 16,
                        27: 17, 28: 18,
                        29: 19, 30: 20,
                        31: 13, 32: 14,
                        33: 5, 34: 6,
                        35: 7, 36: 8,
                        37: 9, 38: 10,
                        39: 11, 40: 12,
                        41: 13, 42: 14,
                        43: 15, 44: 16,
                        45: 17, 46: 18,
                        47: 19, 48: 20,
                        49: 13, 50: 14,
                        51: 15, 52: 16,
                        53: 17, 54: 18,
                        55: 9, 56: 10,
                        57: 11, 72: 26,
                        73: 27, 74: 28,
                        75: 15, 76: 16,
                        77: 17, 78: 18,
                        79: 19, 80: 20,
                        81: 13, 82: 14,
                        83: 15}
    if self.nao_max == 13:
        # Shell layout s,s,p,p,d (2 + 6 + 5 = 13 orbitals).
        self.index_change = torch.LongTensor([0,1,3,4,2,6,7,5,10,11,9,12,8])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x1o+1x1o+1x2e")
        # NOTE(review): presumably the orbitals whose sign is flipped relative to
        # the internal convention - confirm where ``minus_index`` is consumed.
        self.minus_index = torch.LongTensor([3,4,6,7,9,10])
        # The lambda's default arguments name the index range of each shell; it is
        # called immediately so ``basis_def`` ends up being the plain dict it returns.
        # Unlike the other branches, the values here are integer numpy arrays.
        self.basis_def = (lambda s1=[0],s2=[1],p1=[2,3,4],p2=[5,6,7],d1=[8,9,10,11,12]: {
            1 : np.array(s1+s2+p1, dtype=int), # H
            2 : np.array(s1+s2+p1, dtype=int), # He
            5 : np.array(s1+s2+p1+p2+d1, dtype=int), # B
            6 : np.array(s1+s2+p1+p2+d1, dtype=int), # C
            7 : np.array(s1+s2+p1+p2+d1, dtype=int), # N
            8 : np.array(s1+s2+p1+p2+d1, dtype=int), # O
            9 : np.array(s1+s2+p1+p2+d1, dtype=int), # F
            10: np.array(s1+s2+p1+p2+d1, dtype=int), # Ne
            14: np.array(s1+s2+p1+p2+d1, dtype=int), # Si
            15: np.array(s1+s2+p1+p2+d1, dtype=int), # P
            16: np.array(s1+s2+p1+p2+d1, dtype=int), # S
            17: np.array(s1+s2+p1+p2+d1, dtype=int), # Cl
            18: np.array(s1+s2+p1+p2+d1, dtype=int), # Ar
        })()
    elif self.nao_max == 27:
        # Shell layout s,s,s,s,p,p,d,d,f (4 + 6 + 10 + 7 = 27 orbitals).
        self.index_change = torch.LongTensor([0,1,2,3,5,6,4,8,9,7,12,13,11,14,10,17,18,16,19,15,23,24,22,25,21,26,20])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e+1x3o")
        self.minus_index = torch.LongTensor([5,6,8,9,11,12,16,17,21,22,25,26]) # this list should follow the order in abacus.
        self.basis_def = (lambda s1=[0],s2=[1],s3=[2],s4=[3],p1=[4,5,6],p2=[7,8,9],d1=[10,11,12,13,14],d2=[15,16,17,18,19],f1=[20,21,22,23,24,25,26]: {
            1 : s1+s2+p1, # H
            2 : s1+s2+p1, # He
            3 : s1+s2+s3+s4+p1, # Li
            4 : s1+s2+s3+s4+p1, # Be
            5 : s1+s2+p1+p2+d1, # B
            6 : s1+s2+p1+p2+d1, # C
            7 : s1+s2+p1+p2+d1, # N
            8 : s1+s2+p1+p2+d1, # O
            9 : s1+s2+p1+p2+d1, # F
            10: s1+s2+p1+p2+d1, # Ne
            11: s1+s2+s3+s4+p1+p2+d1, # Na
            12: s1+s2+s3+s4+p1+p2+d1, # Mg
            # 13: Al
            14: s1+s2+p1+p2+d1, # Si
            15: s1+s2+p1+p2+d1, # P
            16: s1+s2+p1+p2+d1, # S
            17: s1+s2+p1+p2+d1, # Cl
            18: s1+s2+p1+p2+d1, # Ar
            19: s1+s2+s3+s4+p1+p2+d1, # K
            20: s1+s2+s3+s4+p1+p2+d1, # Ca
            21: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Sc
            22: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Ti
            23: s1+s2+s3+s4+p1+p2+d1+d2+f1, # V
            24: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Cr
            25: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Mn
            26: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Fe
            27: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Co
            28: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Ni
            29: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Cu
            30: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Zn
            31: s1+s2+p1+p2+d1+d2+f1, # Ga
            32: s1+s2+p1+p2+d1+d2+f1, # Ge
            33: s1+s2+p1+p2+d1, # As
            34: s1+s2+p1+p2+d1, # Se
            35: s1+s2+p1+p2+d1, # Br
            36: s1+s2+p1+p2+d1, # Kr
            37: s1+s2+s3+s4+p1+p2+d1, # Rb
            38: s1+s2+s3+s4+p1+p2+d1, # Sr
            39: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Y
            40: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Zr
            41: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Nb
            42: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Mo
            43: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Tc
            44: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Ru
            45: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Rh
            46: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Pd
            47: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Ag
            48: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Cd
            49: s1+s2+p1+p2+d1+d2+f1, # In
            50: s1+s2+p1+p2+d1+d2+f1, # Sn
            51: s1+s2+p1+p2+d1+d2+f1, # Sb
            52: s1+s2+p1+p2+d1+d2+f1, # Te
            53: s1+s2+p1+p2+d1+d2+f1, # I
            54: s1+s2+p1+p2+d1+d2+f1, # Xe
            55: s1+s2+s3+s4+p1+p2+d1, # Cs
            56: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Ba
            #
            79: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Au
            80: s1+s2+s3+s4+p1+p2+d1+d2+f1, # Hg
            81: s1+s2+p1+p2+d1+d2+f1, # Tl
            82: s1+s2+p1+p2+d1+d2+f1, # Pb
            83: s1+s2+p1+p2+d1+d2+f1, # Bi
        })()
    elif self.nao_max == 40:
        # Shell layout s x4, p x4, d x2, f x2 (4 + 12 + 10 + 14 = 40 orbitals).
        self.index_change = torch.LongTensor([0,1,2,3,5,6,4,8,9,7,11,12,10,14,15,13,18,19,17,20,16,23,24,22,25,21,29,30,28,31,27,32,26,36,37,35,38,34,39,33])
        self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x2e+1x2e+1x3o+1x3o")
        self.minus_index = torch.LongTensor([5,6,8,9,11,12,14,15,17,18,22,23,27,28,31,32,34,35,38,39]) # this list should follow the order in abacus.
        self.basis_def = (lambda s1=[0],
                          s2=[1],
                          s3=[2],
                          s4=[3],
                          p1=[4,5,6],
                          p2=[7,8,9],
                          p3=[10,11,12],
                          p4=[13,14,15],
                          d1=[16,17,18,19,20],
                          d2=[21,22,23,24,25],
                          f1=[26,27,28,29,30,31,32],
                          f2=[33,34,35,36,37,38,39]: {
            Element('Ag').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Al').Z: s1+s2+s3+s4+p1+p2+p3+p4+d1,
            Element('Ar').Z: s1+s2+p1+p2+d1,
            Element('As').Z: s1+s2+p1+p2+d1,
            Element('Au').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Ba').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Be').Z: s1+s2+s3+s4+p1,
            Element('B').Z: s1+s2+p1+p2+d1,
            Element('Bi').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Br').Z: s1+s2+p1+p2+d1,
            Element('Ca').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Cd').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('C').Z: s1+s2+p1+p2+d1,
            Element('Cl').Z: s1+s2+p1+p2+d1,
            Element('Co').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Cr').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Cs').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Cu').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Fe').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('F').Z: s1+s2+p1+p2+d1,
            Element('Ga').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Ge').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('He').Z: s1+s2+p1,
            Element('Hf').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1+f2, # Hf_gga_10au_100Ry_4s2p2d2f.orb
            Element('H').Z: s1+s2+p1,
            Element('Hg').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('I').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('In').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Ir').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('K').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Kr').Z: s1+s2+p1+p2+d1,
            Element('Li').Z: s1+s2+s3+s4+p1,
            Element('Mg').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Mn').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Mo').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Na').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Nb').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Ne').Z: s1+s2+p1+p2+d1,
            Element('N').Z: s1+s2+p1+p2+d1,
            Element('Ni').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('O').Z: s1+s2+p1+p2+d1,
            Element('Os').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Pb').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Pd').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('P').Z: s1+s2+p1+p2+d1,
            Element('Pt').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Rb').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Re').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Rh').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Ru').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Sb').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Sc').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Se').Z: s1+s2+p1+p2+d1,
            Element('S').Z: s1+s2+p1+p2+d1,
            Element('Si').Z: s1+s2+p1+p2+d1,
            Element('Sn').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Sr').Z: s1+s2+s3+s4+p1+p2+d1,
            Element('Ta').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1+f2, # Ta_gga_10au_100Ry_4s2p2d2f.orb
            Element('Tc').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Te').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Ti').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Tl').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('V').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('W').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1+f2, # W_gga_10au_100Ry_4s2p2d2f.orb
            Element('Xe').Z: s1+s2+p1+p2+d1+d2+f1,
            Element('Y').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Zn').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
            Element('Zr').Z: s1+s2+s3+s4+p1+p2+d1+d2+f1,
        })()
    else:
        raise NotImplementedError(f"NAO max '{self.nao_max}' not supported for 'abacus'.")
def _configure_band_num_control(self, band_num_control):
"""
Configure band number control settings.
Args:
band_num_control (dict or int): Band number control configuration.
If dict, maps atomic numbers to band counts.
If int, specifies a global band count.
"""
if band_num_control is not None and not self.export_reciprocal_values:
if isinstance(band_num_control, dict):
# Convert atomic numbers to integers for consistency
self.band_num_control = {int(k): v for k, v in band_num_control.items()}
elif isinstance(band_num_control, int):
self.band_num_control = band_num_control
else:
self.band_num_control = None
else:
self.band_num_control = None
def _create_hamiltonian_layer(self, irreps_in, irreps_out):
    """
    Build one :class:`HamLayer` output head.

    The hidden irreps are set equal to ``irreps_in``, the nonlinearity comes from
    ``self.nonlinearity_type`` and the residual connection is always enabled.

    Args:
        irreps_in (o3.Irreps): Input irreducible representations.
        irreps_out (o3.Irreps): Output irreducible representations.
    Returns:
        HamLayer: The newly constructed layer.
    """
    layer_options = dict(
        irreps_in=irreps_in,
        feature_irreps_hidden=irreps_in,
        irreps_out=irreps_out,
        nonlinearity_type=self.nonlinearity_type,
        resnet=True,
    )
    return HamLayer(**layer_options)
def merge_tensor_components(self, spherical_components):
"""
Merge spherical tensor components into a matrix using Clebsch-Gordan coefficients.
Args:
spherical_components (list of torch.Tensor): List of tensors representing spherical
components of irreducible representations. Each tensor has shape
(batch_size, component_dim).
Returns:
torch.Tensor: Merged matrix with shape (batch_size, nao_max * nao_max) representing
flattened matrix of merged components.
"""
batch_size = spherical_components[0].shape[0]
result_matrix = torch.zeros(batch_size, self.nao_max, self.nao_max).type_as(spherical_components[0])
component_index = 0 # Index for accessing the correct irreps
row_start_idx = 0
for _, row_orbital in self.row:
row_dimension = 2 * row_orbital.l + 1 # Dimension based on angular momentum
col_start_idx = 0
for _, col_orbital in self.col:
col_dimension = 2 * col_orbital.l + 1 # Dimension based on angular momentum
# Iterate through allowed angular momentum values
for angular_momentum in range(abs(row_orbital.l - col_orbital.l), row_orbital.l + col_orbital.l + 1):
# Compute inverse spherical tensor product
clebsch_gordan_coef = math.sqrt(2 * angular_momentum + 1) * self.cg_calculator(
row_orbital.l, col_orbital.l, angular_momentum
).unsqueeze(0)
tensor_product = (clebsch_gordan_coef * spherical_components[component_index].unsqueeze(-2).unsqueeze(-2)).sum(-1)
# Add product to appropriate part of the result matrix
submatrix = result_matrix.narrow(-2, row_start_idx, row_dimension).narrow(-1, col_start_idx, col_dimension)
submatrix += tensor_product
component_index += 1
col_start_idx += col_dimension
row_start_idx += row_dimension
return result_matrix.reshape(-1, self.nao_max * self.nao_max)
def merge_rank2_tensor_components(self, spherical_components):
    """
    Build one 3x3 block per (row orbital, column orbital) pair from rank-2 components.

    Every pair consumes three consecutive entries of ``spherical_components``
    (coupled angular momentum L = 0, 1, 2). Each entry is contracted over its last
    axis with ``sqrt(2L + 1) * self.cg_calculator(1, 1, L)`` and summed into the
    pair's block. Finally both 3-dimensional axes of every block are re-indexed
    with the order ``[2, 0, 1]``.

    Args:
        spherical_components (list of torch.Tensor): Components ordered pair by
            pair, three per pair (L = 0, 1, 2).

    Returns:
        torch.Tensor: Tensor of shape (batch_size, num_blocks, 3, 3) with the
        axis permutation applied.
    """
    first = spherical_components[0]
    blocks = torch.zeros(first.shape[0], self.num_blocks, 3, 3).type_as(first)
    axis_order = torch.LongTensor([2, 0, 1])  # re-indexing applied to both 3-axes

    orbital_pairs = ((ir_row, ir_col) for _, ir_row in self.row for _, ir_col in self.col)
    for block_id, _ in enumerate(orbital_pairs):
        for coupled_l in range(3):
            weights = math.sqrt(2 * coupled_l + 1) * self.cg_calculator(1, 1, coupled_l).unsqueeze(0)
            # Three components per block: flat position is 3 * block_id + L.
            component = spherical_components[3 * block_id + coupled_l].unsqueeze(-2).unsqueeze(-2)
            blocks[:, block_id] += (weights * component).sum(-1)

    row_selector = axis_order[:, None]
    col_selector = axis_order[None, :]
    return blocks[:, :, row_selector, col_selector]
[docs]
def merge_rank0_tensor_components(self, spherical_components):
    """
    Spread one scalar component over each orbital block of a matrix.

    The output is tiled into blocks of size (2*l_row + 1) x (2*l_col + 1), one per
    (row orbital, column orbital) pair from ``self.row`` and ``self.col``. The
    i-th entry of ``spherical_components`` is broadcast over the i-th block.

    Args:
        spherical_components (list of torch.Tensor): One tensor per orbital pair;
            each is unsqueezed on its last axis and broadcast over the block.

    Returns:
        torch.Tensor: Matrix of shape (batch_size, nao_max, nao_max).
    """
    lead = spherical_components[0]
    merged = torch.zeros(lead.shape[0], self.nao_max, self.nao_max).type_as(lead)

    pair_id = 0
    row_begin = 0
    for _, ir_row in self.row:
        row_end = row_begin + 2 * ir_row.l + 1
        col_begin = 0
        for _, ir_col in self.col:
            col_end = col_begin + 2 * ir_col.l + 1
            scalar = spherical_components[pair_id].unsqueeze(-1)
            merged[:, row_begin:row_end, col_begin:col_end] += scalar.expand(
                -1, row_end - row_begin, col_end - col_begin
            )
            pair_id += 1
            col_begin = col_end
        row_begin = row_end

    return merged
[docs]
def construct_j_coupling_matrix(self, coupling_coefficients):
    """
    Expand J-coupling coefficients into an orbital-resolved matrix.

    The coefficients are split along the last axis into chunks of sizes
    ``self.J_irreps_dim``. Without spin-orbit coupling (``self.soc_switch`` falsy)
    the chunks go straight to :meth:`merge_rank0_tensor_components`. With it, they
    are first merged into one 3x3 block per orbital pair by
    :meth:`merge_rank2_tensor_components`, and each block is then copied to every
    (row, column) orbital position belonging to that pair.

    Args:
        coupling_coefficients (torch.Tensor): Coupling coefficients; the last axis
            holds the concatenated irrep components.

    Returns:
        torch.Tensor: Shape (batch_size, nao_max, nao_max, 3, 3) when
        ``self.soc_switch`` is set, otherwise (batch_size, nao_max, nao_max).
    """
    pieces = torch.split(coupling_coefficients, self.J_irreps_dim.tolist(), dim=-1)

    if not self.soc_switch:
        # Scalar coupling: a single number per orbital pair.
        return self.merge_rank0_tensor_components(pieces)

    # Rank-2 coupling: one 3x3 tensor per orbital pair.
    pair_blocks = self.merge_rank2_tensor_components(pieces)
    n_samples = pair_blocks.shape[0]
    full = torch.zeros(n_samples, self.nao_max, self.nao_max, 3, 3).type_as(pair_blocks)

    pair_id = 0
    row_begin = 0
    for _, ir_row in self.row:
        n_rows = 2 * ir_row.l + 1
        col_begin = 0
        for _, ir_col in self.col:
            n_cols = 2 * ir_col.l + 1
            tiled = pair_blocks[:, pair_id].reshape(-1, 1, 1, 3, 3).expand(
                n_samples, n_rows, n_cols, 3, 3
            )
            full[:, row_begin:row_begin + n_rows, col_begin:col_begin + n_cols] = tiled
            pair_id += 1
            col_begin += n_cols
        row_begin += n_rows

    return full
[docs]
def construct_k_coupling_matrix(self, coupling_coefficients):
    """
    Expand per-block exchange (K) coefficients into an orbital-resolved matrix.

    The input is flattened to (n, n_blocks). Column ``i`` holds the coefficient of
    the i-th (row orbital, column orbital) pair, which is copied to every entry of
    that pair's (2*l_row + 1) x (2*l_col + 1) block.

    Args:
        coupling_coefficients (torch.Tensor): Exchange coefficients whose first
            axis indexes atoms or edges; remaining axes are flattened into blocks.

    Returns:
        torch.Tensor: Matrix of shape (n_atoms_or_edges, nao_max, nao_max).
    """
    per_block = coupling_coefficients.reshape(coupling_coefficients.shape[0], -1)
    n_samples = per_block.shape[0]
    full = torch.zeros(n_samples, self.nao_max, self.nao_max).type_as(per_block)

    pair_id = 0
    row_begin = 0
    for _, ir_row in self.row:
        n_rows = 2 * ir_row.l + 1
        col_begin = 0
        for _, ir_col in self.col:
            n_cols = 2 * ir_col.l + 1
            tiled = per_block[:, pair_id].reshape(-1, 1, 1).expand(n_samples, n_rows, n_cols)
            full[:, row_begin:row_begin + n_rows, col_begin:col_begin + n_cols] = tiled
            pair_id += 1
            col_begin += n_cols
        row_begin += n_rows

    return full
[docs]
def reorder_matrix(self, matrix):
    """
    Reorder matrix elements to match the atomic-orbital convention of the DFT code.

    Two optional transformations are applied to every (nao_max x nao_max) matrix:

    1. If ``self.index_change`` is not None, rows and columns are permuted with it.
    2. If the instance has a ``minus_index`` attribute, the rows and then the
       columns listed there change sign (entries on both a flipped row and a
       flipped column therefore keep their sign).

    The input tensor is never modified. The sign flip is done out of place, so
    callers that keep a reference to ``matrix`` are not affected even when the
    internal reshape is a view.

    Args:
        matrix (torch.Tensor): Flattened matrices of shape
            (batch_size, nao_max**2).

    Returns:
        torch.Tensor: Reordered matrices of shape (batch_size, nao_max**2). When
        neither transformation is configured, ``matrix`` itself is returned.
    """
    index_change = self.index_change
    minus_index = getattr(self, 'minus_index', None)

    if index_change is None and minus_index is None:
        return matrix

    matrix_form = matrix.reshape(-1, self.nao_max, self.nao_max)

    if index_change is not None:
        # Advanced indexing permutes rows and columns simultaneously.
        matrix_form = matrix_form[:, index_change[:, None], index_change[None, :]]

    if minus_index is not None:
        # +1/-1 per orbital; the outer product reproduces "flip rows, then flip
        # columns" without writing into a view of the caller's tensor.
        signs = torch.ones(self.nao_max, dtype=matrix_form.dtype, device=matrix_form.device)
        signs[minus_index] = -1
        matrix_form = matrix_form * (signs[:, None] * signs[None, :])

    return matrix_form.reshape(-1, self.nao_max ** 2)
[docs]
def construct_molecular_hamiltonian(self, data, onsite_hamiltonian, offsite_hamiltonian):
    """
    Assemble dense molecular Hamiltonians from on-site and off-site blocks.

    Every atom contributes a row of ``max_atoms`` blocks of size
    (nao_max x nao_max): its on-site block on the diagonal and one off-site block
    per edge. Orbitals that are not listed in ``self.basis_def`` for an atom's
    atomic number are padding and are removed from the result.

    The placement uses ``global_atom_index % max_atoms`` as the atom's position
    inside its molecule, which is only valid when every molecule in the batch has
    the same number of atoms; this is checked explicitly.

    Args:
        data: Batch object providing
            - z: atomic numbers, used to index the orbital mask,
            - node_counts: number of atoms of each molecule,
            - edge_index: (source, target) atom indices of each edge.
        onsite_hamiltonian (torch.Tensor): On-site blocks, shape
            (n_atoms, nao_max**2).
        offsite_hamiltonian (torch.Tensor): Off-site blocks, shape
            (n_edges, nao_max**2).

    Returns:
        torch.Tensor: Molecular Hamiltonians stacked along the first axis, shape
        (n_molecules * n_orbitals, n_orbitals), where n_orbitals is the number of
        valid (non-padded) orbitals per molecule.

    Raises:
        ValueError: If the molecules in the batch do not all have the same number
            of atoms.
    """
    nao = self.nao_max
    max_atoms_per_crystal = int(torch.max(data.node_counts).item())
    n_atoms = data.z.shape[0]

    # The modulo-based placement below silently mis-assigns blocks for
    # heterogeneous batches, so fail loudly instead.
    if n_atoms % max_atoms_per_crystal != 0 or bool((data.node_counts != max_atoms_per_crystal).any()):
        raise ValueError(
            "construct_molecular_hamiltonian requires all molecules in the batch "
            "to have the same number of atoms"
        )
    n_molecules = n_atoms // max_atoms_per_crystal

    # Orbital validity table indexed by atomic number; basis_def stores
    # 1-indexed orbital positions, hence the shift.
    atom_orbital_mask = torch.zeros((99, nao), device=data.z.device).type_as(data.z)
    for atomic_number, orbital_indices in self.basis_def.items():
        zero_indexed_indices = [idx - 1 for idx in orbital_indices]
        atom_orbital_mask[atomic_number][zero_indexed_indices] = 1

    # Per-molecule orbital mask and its outer product (valid orbital pairs).
    crystal_orbital_mask = atom_orbital_mask[data.z].view(-1, max_atoms_per_crystal * nao)
    orbital_interaction_mask = crystal_orbital_mask[:, :, None] * crystal_orbital_mask[:, None, :]
    orbital_interaction_mask = orbital_interaction_mask.view(-1, max_atoms_per_crystal * nao)

    # combined[i, j] holds the flattened block between global atom i and the
    # j-th atom of the same molecule.
    combined_hamiltonian = torch.zeros(
        [n_atoms, max_atoms_per_crystal, nao ** 2]
    ).type_as(onsite_hamiltonian)

    atom_indices = torch.arange(n_atoms, device=data.z.device).type_as(data.z)
    combined_hamiltonian[atom_indices, atom_indices % max_atoms_per_crystal] = onsite_hamiltonian

    source_atoms, target_atoms = data.edge_index
    combined_hamiltonian[source_atoms, target_atoms % max_atoms_per_crystal] = offsite_hamiltonian

    # (atom_i, atom_j, orb_a, orb_b) -> (atom_i, orb_a, atom_j, orb_b) -> 2-D
    combined_hamiltonian = combined_hamiltonian.reshape(n_atoms, max_atoms_per_crystal, nao, nao)
    combined_hamiltonian = combined_hamiltonian.permute((0, 2, 1, 3))
    combined_hamiltonian = combined_hamiltonian.reshape(n_atoms * nao, max_atoms_per_crystal * nao)

    # Drop padded orbitals.
    valid_elements = torch.masked_select(combined_hamiltonian, orbital_interaction_mask > 0)

    # Each molecule contributes n_orbitals**2 entries; integer arithmetic avoids
    # float rounding in the square root.
    orbitals_per_molecule = math.isqrt(valid_elements.shape[0] // n_molecules)

    return valid_elements.reshape(-1, orbitals_per_molecule)
[docs]
def concatenate_hamiltonians_by_crystal(self, data, onsite_hamiltonians, offsite_hamiltonians):
    """
    Interleave on-site and off-site Hamiltonian rows crystal by crystal.

    The output is ``[onsite(crystal 0), offsite(crystal 0), onsite(crystal 1),
    offsite(crystal 1), ...]`` concatenated along the first axis. Atoms and edges
    are assumed to be stored contiguously per crystal, because the tensors are cut
    with ``torch.split`` using per-crystal counts.

    Edge counts are computed with an explicit length equal to the number of
    crystals, so crystals without any edge (including the last ones in the batch)
    contribute an empty off-site chunk instead of causing an index error.

    Args:
        data: Batch object providing
            - node_counts: number of atoms of each crystal,
            - edge_index: (source, target) atom indices of each edge,
            - batch: crystal index of each atom.
        onsite_hamiltonians (torch.Tensor): On-site rows, shape
            (total_atoms, matrix_dimension).
        offsite_hamiltonians (torch.Tensor): Off-site rows, shape
            (total_edges, matrix_dimension).

    Returns:
        torch.Tensor: Rows of all crystals with alternating on-site and off-site
        chunks, shape (total_atoms + total_edges, matrix_dimension).
    """
    atoms_per_crystal = data.node_counts
    num_crystals = len(atoms_per_crystal)
    onsite_by_crystal = torch.split(onsite_hamiltonians, atoms_per_crystal.tolist(), dim=0)

    # An edge belongs to the crystal of its source atom. `minlength` guarantees one
    # count per crystal even when trailing crystals have no edges.
    source_nodes = data.edge_index[0]
    edges_per_crystal = torch.bincount(data.batch[source_nodes], minlength=num_crystals)
    offsite_by_crystal = torch.split(offsite_hamiltonians, edges_per_crystal.tolist(), dim=0)

    concatenated_hamiltonians = []
    for onsite_chunk, offsite_chunk in zip(onsite_by_crystal, offsite_by_crystal):
        concatenated_hamiltonians.append(onsite_chunk)
        concatenated_hamiltonians.append(offsite_chunk)

    return torch.cat(concatenated_hamiltonians, dim=0)
[docs]
def symmetrize_hamiltonian(self, hamiltonian, is_soc=False, inverse_edges=None, symmetry_type='hermitian'):
    """
    Enforce Hermitian or anti-Hermitian symmetry on flattened Hamiltonian blocks.

    Each block is paired with a partner block: itself for on-site terms
    (``inverse_edges`` is None), or the block of the reversed edge for off-site
    terms. The result is ``0.5 * (H + P^dagger)`` for ``'hermitian'`` and
    ``0.5 * (H - P^dagger)`` for ``'anti-hermitian'``, where ``P^dagger`` is the
    conjugate transpose of the partner.

    When ``self.symmetrize`` is falsy nothing is symmetrized: the input is returned
    unchanged, except that spin-orbit inputs are still reshaped to
    (-1, (2 * nao_max)**2). ``symmetry_type`` is not validated in that case.

    Args:
        hamiltonian (torch.Tensor): Hamiltonian blocks in flattened form.
        is_soc (bool, optional): If True the blocks have side length
            ``2 * nao_max`` instead of ``nao_max``. Defaults to False.
        inverse_edges (torch.Tensor, optional): For every edge, the index of the
            reversed edge. None selects the on-site behaviour. Defaults to None.
        symmetry_type (str, optional): ``'hermitian'`` or ``'anti-hermitian'``.
            Defaults to ``'hermitian'``.

    Returns:
        torch.Tensor: Symmetrized blocks, shape (-1, dimension**2).

    Raises:
        ValueError: If symmetrization is enabled and ``symmetry_type`` is neither
            of the two supported values.
    """
    side = self.nao_max
    if is_soc:
        side = 2 * self.nao_max

    if not self.symmetrize:
        # Pass-through; only the SOC layout is normalised.
        if is_soc:
            return hamiltonian.reshape(-1, side**2)
        return hamiltonian

    blocks = hamiltonian.reshape(-1, side, side)

    if symmetry_type not in ('hermitian', 'anti-hermitian'):
        raise ValueError(f"Unknown symmetry type: {symmetry_type}")

    # On-site blocks pair with themselves, off-site blocks with the reversed edge.
    partner = blocks if inverse_edges is None else blocks[inverse_edges]
    adjoint = partner.conj().permute(0, 2, 1)
    if symmetry_type == 'anti-hermitian':
        adjoint = -adjoint

    symmetrized = 0.5 * (blocks + adjoint)
    return symmetrized.reshape(-1, side**2)
[docs]
def symmetrize_onsite_hamiltonian(self, hamiltonian, hermitian=True):
    """
    Symmetrize on-site Hamiltonian blocks (no spin-orbit coupling).

    Thin wrapper around :meth:`symmetrize_hamiltonian` with ``is_soc=False`` and
    no inverse-edge mapping.

    Args:
        hamiltonian (torch.Tensor): On-site Hamiltonian matrix elements.
        hermitian (bool, optional): True selects the ``'hermitian'`` mode, False
            the ``'anti-hermitian'`` mode. Defaults to True.

    Returns:
        torch.Tensor: Whatever :meth:`symmetrize_hamiltonian` returns.
    """
    if hermitian:
        mode = 'hermitian'
    else:
        mode = 'anti-hermitian'
    return self.symmetrize_hamiltonian(
        hamiltonian, symmetry_type=mode, inverse_edges=None, is_soc=False
    )
[docs]
def symmetrize_offsite_hamiltonian(self, hamiltonian, inverse_edges, hermitian=True):
    """
    Symmetrize off-site Hamiltonian blocks (no spin-orbit coupling).

    Thin wrapper around :meth:`symmetrize_hamiltonian` with ``is_soc=False``; the
    inverse-edge mapping is forwarded so each edge is paired with its reverse.

    Args:
        hamiltonian (torch.Tensor): Off-site Hamiltonian matrix elements.
        inverse_edges (torch.Tensor): For every edge, the index of the reversed edge.
        hermitian (bool, optional): True selects the ``'hermitian'`` mode, False
            the ``'anti-hermitian'`` mode. Defaults to True.

    Returns:
        torch.Tensor: Whatever :meth:`symmetrize_hamiltonian` returns.
    """
    if hermitian:
        mode = 'hermitian'
    else:
        mode = 'anti-hermitian'
    return self.symmetrize_hamiltonian(
        hamiltonian, symmetry_type=mode, inverse_edges=inverse_edges, is_soc=False
    )
[docs]
def symmetrize_onsite_hamiltonian_soc(self, hamiltonian, hermitian=True):
    """
    Symmetrize on-site Hamiltonian blocks that include spin-orbit coupling.

    Thin wrapper around :meth:`symmetrize_hamiltonian` with ``is_soc=True``
    (blocks of side length ``2 * nao_max``) and no inverse-edge mapping.

    Args:
        hamiltonian (torch.Tensor): On-site SOC Hamiltonian matrix elements.
        hermitian (bool, optional): True selects the ``'hermitian'`` mode, False
            the ``'anti-hermitian'`` mode. Defaults to True.

    Returns:
        torch.Tensor: Whatever :meth:`symmetrize_hamiltonian` returns.
    """
    if hermitian:
        mode = 'hermitian'
    else:
        mode = 'anti-hermitian'
    return self.symmetrize_hamiltonian(
        hamiltonian, symmetry_type=mode, inverse_edges=None, is_soc=True
    )
[docs]
def symmetrize_offsite_hamiltonian_soc(self, hamiltonian, inverse_edges, hermitian=True):
    """
    Symmetrize off-site Hamiltonian blocks that include spin-orbit coupling.

    Thin wrapper around :meth:`symmetrize_hamiltonian` with ``is_soc=True``
    (blocks of side length ``2 * nao_max``); the inverse-edge mapping is forwarded
    so each edge is paired with its reverse.

    Args:
        hamiltonian (torch.Tensor): Off-site SOC Hamiltonian matrix elements.
        inverse_edges (torch.Tensor): For every edge, the index of the reversed edge.
        hermitian (bool, optional): True selects the ``'hermitian'`` mode, False
            the ``'anti-hermitian'`` mode. Defaults to True.

    Returns:
        torch.Tensor: Whatever :meth:`symmetrize_hamiltonian` returns.
    """
    if hermitian:
        mode = 'hermitian'
    else:
        mode = 'anti-hermitian'
    return self.symmetrize_hamiltonian(
        hamiltonian, symmetry_type=mode, inverse_edges=inverse_edges, is_soc=True
    )
[docs]
def calculate_band_energies_with_overlap(self, onsite_hamiltonian, offsite_hamiltonian,
                                         onsite_overlap, offsite_overlap, crystal_data,
                                         export_reciprocal_values=False):
    """
    Calculate band energies from real-space Hamiltonian and overlap blocks.

    For every crystal in the batch the real-space blocks are Fourier-summed into
    k-space matrices H(k) and S(k), padded orbitals are masked out, and the
    generalized eigenvalue problem ``H(k) c = E S(k) c`` is solved through a
    Cholesky factorisation of S(k).

    Two overlap matrices are assembled: the reference one from
    ``crystal_data.Son`` / ``crystal_data.Soff`` and the predicted one from
    ``onsite_overlap`` / ``offsite_overlap``. The eigenproblem is solved with the
    *reference* overlap; the predicted overlap is only returned when
    ``export_reciprocal_values`` is True.

    Args:
        onsite_hamiltonian (torch.Tensor): On-site Hamiltonian blocks, shape
            (total_atoms, nao_max**2).
        offsite_hamiltonian (torch.Tensor): Off-site Hamiltonian blocks, shape
            (total_edges, nao_max**2).
        onsite_overlap (torch.Tensor): Predicted on-site overlap blocks, shape
            (total_atoms, nao_max**2).
        offsite_overlap (torch.Tensor): Predicted off-site overlap blocks, shape
            (total_edges, nao_max**2).
        crystal_data: Batch object providing
            - edge_index: (source, target) atom indices of each edge,
            - cell: unit cells, shape (batch_size, 3, 3),
            - z: atomic numbers,
            - node_counts: number of atoms of each crystal,
            - batch: crystal index of each atom,
            - k_vecs: k-points of each crystal,
            - nbr_shift: neighbour-cell shift of each edge,
            - Son / Soff: reference overlap blocks,
            - dSon / dSoff: overlap derivatives (read only when
              ``export_reciprocal_values`` is True).
        export_reciprocal_values (bool, optional): Also build and return the
            k-space matrices. Defaults to False.

    Returns:
        tuple: If ``export_reciprocal_values`` is False:
        ``(band_energies, wavefunctions, band_gaps, symmetrized_hamiltonians)``
        where wavefunctions and the Cholesky-transformed Hamiltonians are
        flattened and concatenated over crystals. If True:
        ``(band_energies, wavefunctions, hamiltonians_k_space, overlaps_k_space,
        overlap_derivatives_k_space, band_gaps)`` with the per-crystal tensors
        stacked (which requires identical shapes across crystals);
        ``overlaps_k_space`` is built from the predicted overlap.
    """
    source_indices, target_indices = crystal_data.edge_index
    lattice_vectors = crystal_data.cell  # shape: (batch_size, 3, 3)
    batch_size = lattice_vectors.shape[0]
    # Create orbital validity mask based on atomic numbers
    # NOTE(review): basis_def indices are used as-is here (no 1->0 shift) — confirm
    # this matches the convention of the configured basis definition.
    atomic_orbital_mask = torch.zeros((99, self.nao_max)).type_as(crystal_data.z)
    for atomic_number, orbital_indices in self.basis_def.items():
        atomic_orbital_mask[atomic_number][orbital_indices] = 1
    # Get orbital mask for each atom
    orbital_mask = atomic_orbital_mask[crystal_data.z]  # shape: [total_atoms, nao_max]
    orbital_mask_by_crystal = torch.split(orbital_mask, crystal_data.node_counts.tolist(), dim=0)
    # Create interaction masks for each crystal
    orbital_interaction_masks = []
    for idx in range(batch_size):
        # Create outer product of orbital masks for interaction masking
        orbital_interaction_masks.append(
            orbital_mask_by_crystal[idx].reshape(-1, 1) *
            orbital_mask_by_crystal[idx].reshape(1, -1)
        )
    # Set the number of valence electrons per atom type
    valence_electron_counts = torch.zeros((99,)).type_as(crystal_data.z)
    for atomic_number, count in self.num_valence.items():
        valence_electron_counts[atomic_number] = count
    # Get valence electron count for each atom and then for each crystal
    atom_valence_counts = valence_electron_counts[crystal_data.z]  # shape: [total_atoms]
    crystal_valence_counts = scatter(atom_valence_counts, crystal_data.batch, dim=0)  # shape: [batch_size]
    # Initialize band window sizes if needed
    if self.band_num_control is not None:
        bands_per_atom_type = torch.zeros((99,)).type_as(crystal_data.z)
        for atomic_number, band_count in self.band_num_control.items():
            bands_per_atom_type[atomic_number] = band_count
        atom_band_counts = bands_per_atom_type[crystal_data.z]  # shape: [total_atoms]
        crystal_band_counts = scatter(atom_band_counts, crystal_data.batch, dim=0)  # shape: [batch_size]
    # Split data by crystal
    atoms_per_crystal = crystal_data.node_counts
    atom_index_offsets = torch.cumsum(atoms_per_crystal, dim=0) - atoms_per_crystal
    # Split Hamiltonians and overlaps by crystal
    onsite_hamiltonians_by_crystal = torch.split(onsite_hamiltonian, atoms_per_crystal.tolist(), dim=0)
    reference_onsite_overlaps = torch.split(crystal_data.Son, atoms_per_crystal.tolist(), dim=0)
    predicted_onsite_overlaps = torch.split(onsite_overlap, atoms_per_crystal.tolist(), dim=0)
    # Count edges per crystal and split edge data.
    # dim_size pins the output length to the number of crystals; without it the
    # result is truncated when the last crystal(s) have no edges and the
    # per-crystal indexing below would run out of range.
    edge_counter = torch.ones_like(source_indices)
    edges_per_crystal = scatter(
        edge_counter, crystal_data.batch[source_indices], dim=0, dim_size=batch_size
    )
    edge_index_offsets = torch.cumsum(edges_per_crystal, dim=0) - edges_per_crystal
    offsite_hamiltonians_by_crystal = torch.split(offsite_hamiltonian, edges_per_crystal.tolist(), dim=0)
    reference_offsite_overlaps = torch.split(crystal_data.Soff, edges_per_crystal.tolist(), dim=0)
    predicted_offsite_overlaps = torch.split(offsite_overlap, edges_per_crystal.tolist(), dim=0)
    # Prepare derivatives if exporting reciprocal values
    if export_reciprocal_values:
        onsite_overlap_derivatives = torch.split(crystal_data.dSon, atoms_per_crystal.tolist(), dim=0)
        offsite_overlap_derivatives = torch.split(crystal_data.dSoff, edges_per_crystal.tolist(), dim=0)
    # Initialize results containers
    band_energies = []
    wavefunctions = []
    reciprocal_hamiltonians = []
    symmetrized_hamiltonians = []
    reciprocal_overlaps = []
    reciprocal_overlap_derivatives = []
    band_gaps = []
    # Process each crystal
    for crystal_idx in range(batch_size):
        k_points = crystal_data.k_vecs[crystal_idx]
        num_atoms = atoms_per_crystal[crystal_idx]
        # Calculate phase factors exp(2*pi*i * k . shift) for every edge and k-point
        # Shape: (edges_in_crystal, num_k_points)
        phase_factors = torch.exp(
            2j * torch.pi * torch.sum(
                crystal_data.nbr_shift[
                    edge_index_offsets[crystal_idx] +
                    torch.arange(edges_per_crystal[crystal_idx]).type_as(source_indices),
                    None, :
                ] * k_points[None, :, :],
                dim=-1
            )
        )
        # Initialize k-space Hamiltonian and overlap matrices
        # (real buffers with a trailing axis of 2, reinterpreted as complex)
        hamiltonian_k = torch.view_as_complex(
            torch.zeros((
                self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 2
            )).type_as(onsite_hamiltonian)
        )
        reference_overlap_k = torch.view_as_complex(
            torch.zeros((
                self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 2
            )).type_as(onsite_hamiltonian)
        )
        predicted_overlap_k = torch.view_as_complex(
            torch.zeros((
                self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 2
            )).type_as(onsite_hamiltonian)
        )
        if export_reciprocal_values:
            overlap_derivatives_k = torch.view_as_complex(
                torch.zeros((
                    self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 3, 2
                )).type_as(onsite_hamiltonian)
            )
        # Add on-site terms to k-space matrices (diagonal atom blocks, no phase)
        atom_indices = torch.arange(num_atoms).type_as(source_indices)
        # Add on-site Hamiltonian and overlap terms
        hamiltonian_k[:, atom_indices, atom_indices, :, :] += onsite_hamiltonians_by_crystal[crystal_idx].reshape(
            -1, self.nao_max, self.nao_max
        )[None, atom_indices, :, :].type_as(hamiltonian_k)
        reference_overlap_k[:, atom_indices, atom_indices, :, :] += reference_onsite_overlaps[crystal_idx].reshape(
            -1, self.nao_max, self.nao_max
        )[None, atom_indices, :, :].type_as(reference_overlap_k)
        predicted_overlap_k[:, atom_indices, atom_indices, :, :] += predicted_onsite_overlaps[crystal_idx].reshape(
            -1, self.nao_max, self.nao_max
        )[None, atom_indices, :, :].type_as(predicted_overlap_k)
        if export_reciprocal_values:
            overlap_derivatives_k[:, atom_indices, atom_indices, :, :, :] += onsite_overlap_derivatives[crystal_idx].reshape(
                -1, self.nao_max, self.nao_max, 3
            )[None, atom_indices, :, :, :].type_as(overlap_derivatives_k)
        # Add off-site terms to k-space matrices.
        # Edges are accumulated one at a time because several edges (different
        # neighbour-cell shifts) can target the same atom pair.
        for edge_idx in range(edges_per_crystal[crystal_idx]):
            # Get local indices within this crystal
            source_atom_idx = source_indices[edge_index_offsets[crystal_idx] + edge_idx] - atom_index_offsets[crystal_idx]
            target_atom_idx = target_indices[edge_index_offsets[crystal_idx] + edge_idx] - atom_index_offsets[crystal_idx]
            # Add contribution for each k-point with phase factor
            hamiltonian_k[:, source_atom_idx, target_atom_idx, :, :] += (
                phase_factors[edge_idx, :, None, None] *
                offsite_hamiltonians_by_crystal[crystal_idx].reshape(-1, self.nao_max, self.nao_max)[None, edge_idx, :, :]
            )
            reference_overlap_k[:, source_atom_idx, target_atom_idx, :, :] += (
                phase_factors[edge_idx, :, None, None] *
                reference_offsite_overlaps[crystal_idx].reshape(-1, self.nao_max, self.nao_max)[None, edge_idx, :, :]
            )
            predicted_overlap_k[:, source_atom_idx, target_atom_idx, :, :] += (
                phase_factors[edge_idx, :, None, None] *
                predicted_offsite_overlaps[crystal_idx].reshape(-1, self.nao_max, self.nao_max)[None, edge_idx, :, :]
            )
        # Add derivative terms if needed
        if export_reciprocal_values:
            for edge_idx in range(edges_per_crystal[crystal_idx]):
                source_atom_idx = source_indices[edge_index_offsets[crystal_idx] + edge_idx] - atom_index_offsets[crystal_idx]
                target_atom_idx = target_indices[edge_index_offsets[crystal_idx] + edge_idx] - atom_index_offsets[crystal_idx]
                overlap_derivatives_k[:, source_atom_idx, target_atom_idx, :, :, :] += (
                    phase_factors[edge_idx, :, None, None, None] *
                    offsite_overlap_derivatives[crystal_idx].reshape(-1, self.nao_max, self.nao_max, 3)[None, edge_idx, :, :, :]
                )
        # Reshape matrices to combine atom and orbital dimensions:
        # (k, atom_i, atom_j, orb_a, orb_b) -> (k, atom_i, orb_a, atom_j, orb_b) -> (k, N, N)
        hamiltonian_k = torch.swapaxes(hamiltonian_k, -2, -3)  # Swap atom and orbital indices
        hamiltonian_k = hamiltonian_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max)
        reference_overlap_k = torch.swapaxes(reference_overlap_k, -2, -3)
        reference_overlap_k = reference_overlap_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max)
        predicted_overlap_k = torch.swapaxes(predicted_overlap_k, -2, -3)
        predicted_overlap_k = predicted_overlap_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max)
        if export_reciprocal_values:
            overlap_derivatives_k = torch.swapaxes(overlap_derivatives_k, -3, -4)
            overlap_derivatives_k = overlap_derivatives_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max, 3)
        # Apply orbital mask to select only valid orbitals
        mask = orbital_interaction_masks[crystal_idx].repeat(self.num_k, 1, 1) > 0
        hamiltonian_k = torch.masked_select(hamiltonian_k, mask)
        # The surviving entries form num_k square matrices; recover their side length.
        num_orbitals = int(math.sqrt(hamiltonian_k.numel() / self.num_k))
        hamiltonian_k = hamiltonian_k.reshape(self.num_k, num_orbitals, num_orbitals)
        reference_overlap_k = torch.masked_select(reference_overlap_k, mask)
        reference_overlap_k = reference_overlap_k.reshape(self.num_k, num_orbitals, num_orbitals)
        predicted_overlap_k = torch.masked_select(predicted_overlap_k, mask)
        predicted_overlap_k = predicted_overlap_k.reshape(self.num_k, num_orbitals, num_orbitals)
        if export_reciprocal_values:
            derivative_mask = orbital_interaction_masks[crystal_idx].unsqueeze(-1).repeat(self.num_k, 1, 1, 3) > 0
            overlap_derivatives_k = torch.masked_select(overlap_derivatives_k, derivative_mask)
            overlap_derivatives_k = overlap_derivatives_k.reshape(self.num_k, num_orbitals, num_orbitals, 3)
        # Solve generalized eigenvalue problem using Cholesky decomposition
        # of the reference overlap: S = L L^H
        cholesky_factor = torch.linalg.cholesky(reference_overlap_k)
        cholesky_factor_conj_transpose = torch.transpose(cholesky_factor.conj(), dim0=-1, dim1=-2)
        cholesky_inverse = torch.linalg.inv(cholesky_factor)
        cholesky_conj_transpose_inverse = torch.linalg.inv(cholesky_factor_conj_transpose)
        # Transform to standard eigenvalue problem: L^-1 H L^-H
        transformed_hamiltonian = torch.bmm(
            torch.bmm(cholesky_inverse, hamiltonian_k),
            cholesky_conj_transpose_inverse
        )
        # Solve eigenvalue problem
        eigenvalues, eigenvectors = torch.linalg.eigh(transformed_hamiltonian)
        # Transform eigenvectors back to original basis; output axes are
        # (k-point, band, orbital)
        eigenvectors = torch.einsum('ijk,ika->iaj', cholesky_conj_transpose_inverse, eigenvectors)
        # Normalize wavefunctions if exporting reciprocal values
        if export_reciprocal_values:
            # Calculate normalization factors <c|S|c> per k-point and band
            normalization_factors = torch.einsum(
                'nai,nij,naj->na',
                torch.conj(eigenvectors),
                reference_overlap_k,
                eigenvectors
            ).real
            # Apply normalization
            normalization_factors = 1 / torch.sqrt(normalization_factors)
            eigenvectors = eigenvectors * normalization_factors.unsqueeze(-1)
            # Store reciprocal space matrices
            reciprocal_hamiltonians.append(hamiltonian_k)
            reciprocal_overlaps.append(predicted_overlap_k)
            reciprocal_overlap_derivatives.append(overlap_derivatives_k)
        # Apply band number control if specified (keep only the lowest bands)
        if self.band_num_control is not None:
            eigenvalues = eigenvalues[:, :crystal_band_counts[crystal_idx]]
            eigenvectors = eigenvectors[:, :crystal_band_counts[crystal_idx], :]
        # Store results
        band_energies.append(torch.transpose(eigenvalues, dim0=-1, dim1=-2))
        wavefunctions.append(eigenvectors)
        symmetrized_hamiltonians.append(transformed_hamiltonian.view(-1))
        # Calculate band gap: lowest level of band N/2 minus highest level of
        # band N/2 - 1 over all k-points, with N the valence electron count
        half_filled_band_index = math.ceil(crystal_valence_counts[crystal_idx] / 2)
        band_gap = (
            torch.min(eigenvalues[:, half_filled_band_index]) -
            torch.max(eigenvalues[:, half_filled_band_index - 1])
        ).unsqueeze(0)
        band_gaps.append(band_gap)
    # Combine results across all crystals
    band_energies = torch.cat(band_energies, dim=0)
    band_gaps = torch.cat(band_gaps, dim=0)
    if export_reciprocal_values:
        # Return reciprocal space matrices
        wavefunctions = torch.stack(wavefunctions, dim=0)
        hamiltonians_k_space = torch.stack(reciprocal_hamiltonians, dim=0)
        overlaps_k_space = torch.stack(reciprocal_overlaps, dim=0)
        overlap_derivatives_k_space = torch.stack(reciprocal_overlap_derivatives, dim=0)
        return band_energies, wavefunctions, hamiltonians_k_space, overlaps_k_space, overlap_derivatives_k_space, band_gaps
    else:
        # Flatten wavefunctions and return
        wavefunctions = [wf.reshape(-1) for wf in wavefunctions]
        wavefunctions = torch.cat(wavefunctions, dim=0)
        symmetrized_hamiltonians = torch.cat(symmetrized_hamiltonians, dim=0)
        return band_energies, wavefunctions, band_gaps, symmetrized_hamiltonians
[docs]
def calculate_band_energies(self, onsite_hamiltonian, offsite_hamiltonian, crystal_data,
                            export_reciprocal_values=False):
    """
    Calculate band energies, wavefunctions and band gaps for a batch of crystals.

    For each crystal, the real-space on-site and off-site blocks of the Hamiltonian and
    of the reference overlap are summed with Bloch phase factors into H(k) and S(k),
    reduced to the orbitals enabled in ``self.basis_def``, and the generalized eigenvalue
    problem ``H(k) psi = E S(k) psi`` is solved through a Cholesky factorization of S(k).

    Args:
        onsite_hamiltonian (torch.Tensor): On-site Hamiltonian blocks, one row per atom;
            reshaped internally to (total_atoms, nao_max, nao_max).
        offsite_hamiltonian (torch.Tensor): Off-site Hamiltonian blocks, one row per edge;
            reshaped internally to (total_edges, nao_max, nao_max).
        crystal_data (DataObject): Batch object; the following attributes are read:
            - edge_index: Indices of connected atom pairs
            - cell: Unit cell vectors (only its leading dimension, the batch size, is used)
            - z: Atomic numbers (must be < 99, the size of the lookup tables)
            - node_counts: Number of atoms in each crystal
            - batch: Batch assignment for each atom
            - k_vecs: k-points for each crystal
            - nbr_shift: Neighbor cell shifts for periodic boundary conditions
            - Son/Soff: Reference overlap blocks (on-site / off-site)
            - dSon/dSoff: Overlap derivative blocks, read only when
              ``export_reciprocal_values`` is True
        export_reciprocal_values (bool, optional): Whether to also build and return the
            reciprocal space matrices H(k), S(k) and dS(k). Defaults to False.

    Returns:
        tuple: If ``export_reciprocal_values`` is False:
                (band_energies, wavefunctions, band_gaps, symmetrized_hamiltonians)
            where wavefunctions and symmetrized_hamiltonians are flattened and
            concatenated over crystals.
            If ``export_reciprocal_values`` is True:
                (band_energies, wavefunctions, hamiltonians_k_space, overlaps_k_space,
                 overlap_derivatives_k_space, band_gaps)
            where everything except band_energies and band_gaps is stacked over crystals
            (which requires all crystals to produce equally shaped tensors).
            ``band_energies`` is concatenated over crystals with bands on the first axis
            and k-points on the second.
    """
    source_indices, target_indices = crystal_data.edge_index
    lattice_vectors = crystal_data.cell  # shape: (batch_size, 3, 3)
    batch_size = lattice_vectors.shape[0]

    # Create orbital validity mask based on atomic numbers
    atomic_orbital_mask = torch.zeros((99, self.nao_max)).type_as(crystal_data.z)
    for atomic_number, orbital_indices in self.basis_def.items():
        atomic_orbital_mask[atomic_number][orbital_indices] = 1

    # Get orbital mask for each atom
    orbital_mask = atomic_orbital_mask[crystal_data.z]  # shape: [total_atoms, nao_max]
    orbital_mask_by_crystal = torch.split(orbital_mask, crystal_data.node_counts.tolist(), dim=0)

    # Outer product of the flattened per-crystal mask with itself: entry (a, b) is
    # non-zero only when both orbital a and orbital b exist
    orbital_interaction_masks = []
    for idx in range(batch_size):
        orbital_interaction_masks.append(
            orbital_mask_by_crystal[idx].reshape(-1, 1) *
            orbital_mask_by_crystal[idx].reshape(1, -1)
        )

    # Set the number of valence electrons per atom type
    valence_electron_counts = torch.zeros((99,)).type_as(crystal_data.z)
    for atomic_number, count in self.num_valence.items():
        valence_electron_counts[atomic_number] = count

    # Get valence electron count for each atom and then for each crystal
    atom_valence_counts = valence_electron_counts[crystal_data.z]  # shape: [total_atoms]
    crystal_valence_counts = scatter(atom_valence_counts, crystal_data.batch, dim=0)  # shape: [batch_size]

    # Per-crystal band counts are only pre-computed for the dict form of band_num_control
    if isinstance(self.band_num_control, dict):
        bands_per_atom_type = torch.zeros((99,)).type_as(crystal_data.z)
        for atomic_number, band_count in self.band_num_control.items():
            bands_per_atom_type[atomic_number] = band_count
        atom_band_counts = bands_per_atom_type[crystal_data.z]  # shape: [total_atoms]
        crystal_band_counts = scatter(atom_band_counts, crystal_data.batch, dim=0)  # shape: [batch_size]

    # Split data by crystal
    atoms_per_crystal = crystal_data.node_counts
    atom_counts_list = atoms_per_crystal.tolist()
    atom_index_offsets = torch.cumsum(atoms_per_crystal, dim=0) - atoms_per_crystal

    # Split Hamiltonians and overlaps by crystal
    onsite_hamiltonians_by_crystal = torch.split(onsite_hamiltonian, atom_counts_list, dim=0)
    onsite_overlaps_by_crystal = torch.split(crystal_data.Son, atom_counts_list, dim=0)

    # Count edges per crystal (an edge belongs to the crystal of its source atom)
    edge_counter = torch.ones_like(source_indices)
    edges_per_crystal = scatter(edge_counter, crystal_data.batch[source_indices], dim=0)
    edge_counts_list = edges_per_crystal.tolist()
    edge_index_offsets = torch.cumsum(edges_per_crystal, dim=0) - edges_per_crystal
    offsite_hamiltonians_by_crystal = torch.split(offsite_hamiltonian, edge_counts_list, dim=0)
    offsite_overlaps_by_crystal = torch.split(crystal_data.Soff, edge_counts_list, dim=0)

    # Prepare derivatives if exporting reciprocal values
    if export_reciprocal_values:
        onsite_overlap_derivatives = torch.split(crystal_data.dSon, atom_counts_list, dim=0)
        offsite_overlap_derivatives = torch.split(crystal_data.dSoff, edge_counts_list, dim=0)

    # Initialize results containers
    band_energies = []
    wavefunctions = []
    reciprocal_hamiltonians = []
    symmetrized_hamiltonians = []
    reciprocal_overlaps = []
    reciprocal_overlap_derivatives = []
    band_gaps = []

    # Process each crystal
    for crystal_idx in range(batch_size):
        k_points = crystal_data.k_vecs[crystal_idx]
        num_atoms = atoms_per_crystal[crystal_idx]
        num_edges = edges_per_crystal[crystal_idx]

        # Positions of this crystal's edges inside the batch-wide edge arrays
        global_edge_ids = edge_index_offsets[crystal_idx] + torch.arange(
            num_edges, device=source_indices.device
        )

        # Calculate phase factors exp(2*pi*i * shift . k)
        # Shape: (edges_in_crystal, num_k_points)
        phase_factors = torch.exp(
            2j * torch.pi * torch.sum(
                crystal_data.nbr_shift[global_edge_ids, None, :] * k_points[None, :, :],
                dim=-1
            )
        )

        # Initialize k-space Hamiltonian and overlap matrices, indexed as
        # (k, atom_i, atom_j, orbital_a, orbital_b)
        hamiltonian_k = torch.view_as_complex(
            torch.zeros((
                self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 2
            )).type_as(onsite_hamiltonian)
        )
        overlap_k = torch.view_as_complex(
            torch.zeros((
                self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 2
            )).type_as(onsite_hamiltonian)
        )
        if export_reciprocal_values:
            overlap_derivatives_k = torch.view_as_complex(
                torch.zeros((
                    self.num_k, num_atoms, num_atoms, self.nao_max, self.nao_max, 3, 2
                )).type_as(onsite_hamiltonian)
            )

        # Add on-site terms to the diagonal atom blocks (no phase factor)
        atom_indices = torch.arange(num_atoms).type_as(source_indices)
        hamiltonian_k[:, atom_indices, atom_indices, :, :] += onsite_hamiltonians_by_crystal[crystal_idx].reshape(
            -1, self.nao_max, self.nao_max
        )[None, atom_indices, :, :].type_as(hamiltonian_k)
        overlap_k[:, atom_indices, atom_indices, :, :] += onsite_overlaps_by_crystal[crystal_idx].reshape(
            -1, self.nao_max, self.nao_max
        )[None, atom_indices, :, :].type_as(overlap_k)
        if export_reciprocal_values:
            overlap_derivatives_k[:, atom_indices, atom_indices, :, :, :] += onsite_overlap_derivatives[crystal_idx].reshape(
                -1, self.nao_max, self.nao_max, 3
            )[None, atom_indices, :, :, :].type_as(overlap_derivatives_k)

        # Atom indices of each edge, local to this crystal
        source_atom_indices = source_indices[global_edge_ids] - atom_index_offsets[crystal_idx]
        target_atom_indices = target_indices[global_edge_ids] - atom_index_offsets[crystal_idx]

        # Reshape flat off-site rows into (edge, orbital_a, orbital_b) blocks
        offsite_h_reshaped = offsite_hamiltonians_by_crystal[crystal_idx].reshape(
            num_edges, self.nao_max, self.nao_max
        )
        offsite_s_reshaped = offsite_overlaps_by_crystal[crystal_idx].reshape(
            num_edges, self.nao_max, self.nao_max
        )

        # Add off-site terms to k-space matrices for each k-point
        for k_idx in range(self.num_k):
            k_phase_factors = phase_factors[:, k_idx].unsqueeze(-1).unsqueeze(-1)
            h_contributions = k_phase_factors * offsite_h_reshaped
            s_contributions = k_phase_factors * offsite_s_reshaped
            # accumulate=True: several edges (different cell shifts) may connect the
            # same atom pair and must be summed rather than overwritten
            hamiltonian_k[k_idx] = torch.index_put(
                hamiltonian_k[k_idx],
                (source_atom_indices, target_atom_indices),
                h_contributions,
                accumulate=True
            )
            overlap_k[k_idx] = torch.index_put(
                overlap_k[k_idx],
                (source_atom_indices, target_atom_indices),
                s_contributions,
                accumulate=True
            )

        # Add derivative terms if needed
        if export_reciprocal_values:
            offsite_ds_reshaped = offsite_overlap_derivatives[crystal_idx].reshape(
                num_edges, self.nao_max, self.nao_max, 3
            )
            for k_idx in range(self.num_k):
                k_phase_factors = phase_factors[:, k_idx].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                ds_contributions = k_phase_factors * offsite_ds_reshaped
                overlap_derivatives_k[k_idx] = torch.index_put(
                    overlap_derivatives_k[k_idx],
                    (source_atom_indices, target_atom_indices),
                    ds_contributions,
                    accumulate=True
                )

        # Reorder to (k, atom_i, orbital_a, atom_j, orbital_b) and merge each
        # (atom, orbital) pair into a single matrix index
        hamiltonian_k = torch.swapaxes(hamiltonian_k, -2, -3)
        hamiltonian_k = hamiltonian_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max)
        overlap_k = torch.swapaxes(overlap_k, -2, -3)
        overlap_k = overlap_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max)
        if export_reciprocal_values:
            overlap_derivatives_k = torch.swapaxes(overlap_derivatives_k, -3, -4)
            overlap_derivatives_k = overlap_derivatives_k.reshape(-1, num_atoms * self.nao_max, num_atoms * self.nao_max, 3)

        # Apply orbital mask to select only valid orbitals
        mask = orbital_interaction_masks[crystal_idx].repeat(self.num_k, 1, 1) > 0
        hamiltonian_k = torch.masked_select(hamiltonian_k, mask)
        # The surviving entries form a square matrix per k-point
        num_orbitals = int(math.sqrt(hamiltonian_k.numel() / self.num_k))
        hamiltonian_k = hamiltonian_k.reshape(self.num_k, num_orbitals, num_orbitals)
        overlap_k = torch.masked_select(overlap_k, mask)
        overlap_k = overlap_k.reshape(self.num_k, num_orbitals, num_orbitals)
        if export_reciprocal_values:
            derivative_mask = orbital_interaction_masks[crystal_idx].unsqueeze(-1).repeat(self.num_k, 1, 1, 3) > 0
            overlap_derivatives_k = torch.masked_select(overlap_derivatives_k, derivative_mask)
            overlap_derivatives_k = overlap_derivatives_k.reshape(self.num_k, num_orbitals, num_orbitals, 3)

        # Solve generalized eigenvalue problem using Cholesky decomposition S = L L^H
        cholesky_factor = torch.linalg.cholesky(overlap_k)
        cholesky_factor_conj_transpose = torch.transpose(cholesky_factor.conj(), dim0=-1, dim1=-2)
        cholesky_inverse = torch.linalg.inv(cholesky_factor)
        cholesky_conj_transpose_inverse = torch.linalg.inv(cholesky_factor_conj_transpose)

        # Transform to standard eigenvalue problem: L^-1 H L^-H
        transformed_hamiltonian = torch.bmm(
            torch.bmm(cholesky_inverse, hamiltonian_k),
            cholesky_conj_transpose_inverse
        )

        # Solve eigenvalue problem
        eigenvalues, eigenvectors = torch.linalg.eigh(transformed_hamiltonian)

        # Transform eigenvectors back to the original basis; the einsum also
        # transposes the result to (k, band, orbital)
        eigenvectors = torch.einsum('ijk,ika->iaj', cholesky_conj_transpose_inverse, eigenvectors)

        # Calculate band gap from the full spectrum (before any band selection)
        half_filled_band_index = math.ceil(crystal_valence_counts[crystal_idx] / 2)
        band_gap = (
            torch.min(eigenvalues[:, half_filled_band_index]) -
            torch.max(eigenvalues[:, half_filled_band_index - 1])
        ).unsqueeze(0)
        band_gaps.append(band_gap)

        # Apply band number control if specified
        if self.band_num_control is not None:
            if isinstance(self.band_num_control, dict):
                # Use pre-calculated band counts: keep the lowest bands
                eigenvalues = eigenvalues[:, :crystal_band_counts[crystal_idx]]
                eigenvectors = eigenvectors[:, :crystal_band_counts[crystal_idx], :]
            else:
                # Calculate band window dynamically
                if isinstance(self.band_num_control, float):
                    band_window = max(1, int(self.band_num_control * half_filled_band_index))
                else:
                    band_window = min(self.band_num_control, half_filled_band_index)
                # Select band window centered around the half-filled band. The start
                # is clamped at 0: a negative start would be interpreted by slicing as
                # "count from the end" and silently select the wrong bands.
                start_band = max(0, half_filled_band_index - band_window)
                end_band = half_filled_band_index + band_window
                eigenvalues = eigenvalues[:, start_band:end_band]
                eigenvectors = eigenvectors[:, start_band:end_band, :]

        # Normalize wavefunctions if exporting reciprocal values
        if export_reciprocal_values:
            # <psi| S(k) |psi> for every (k, band)
            normalization_factors = torch.einsum(
                'nai,nij,naj->na',
                torch.conj(eigenvectors),
                overlap_k,
                eigenvectors
            ).real
            normalization_factors = 1 / torch.sqrt(normalization_factors)
            eigenvectors = eigenvectors * normalization_factors.unsqueeze(-1)
            # Store reciprocal space matrices
            reciprocal_hamiltonians.append(hamiltonian_k)
            reciprocal_overlaps.append(overlap_k)
            reciprocal_overlap_derivatives.append(overlap_derivatives_k)

        # Store results; eigenvalues are transposed to (band, k)
        band_energies.append(torch.transpose(eigenvalues, dim0=-1, dim1=-2))
        wavefunctions.append(eigenvectors)
        symmetrized_hamiltonians.append(transformed_hamiltonian.view(-1))

    # Combine results across all crystals
    band_energies = torch.cat(band_energies, dim=0)
    band_gaps = torch.cat(band_gaps, dim=0)
    if export_reciprocal_values:
        # Return reciprocal space matrices
        wavefunctions = torch.stack(wavefunctions, dim=0)
        hamiltonians_k_space = torch.stack(reciprocal_hamiltonians, dim=0)
        overlaps_k_space = torch.stack(reciprocal_overlaps, dim=0)
        overlap_derivatives_k_space = torch.stack(reciprocal_overlap_derivatives, dim=0)
        return band_energies, wavefunctions, hamiltonians_k_space, overlaps_k_space, overlap_derivatives_k_space, band_gaps
    else:
        # Flatten wavefunctions and return
        wavefunctions = [wf.reshape(-1) for wf in wavefunctions]
        wavefunctions = torch.cat(wavefunctions, dim=0)
        symmetrized_hamiltonians = torch.cat(symmetrized_hamiltonians, dim=0)
        return band_energies, wavefunctions, band_gaps, symmetrized_hamiltonians
[docs]
def calculate_band_energies_with_spin_orbit_coupling(self, real_onsite, imag_onsite, real_offsite, imag_offsite, crystal_data):
    """
    Calculate electronic band structure with spin-orbit coupling (SOC).

    The SOC Hamiltonian is supplied as separate real and imaginary parts whose matrices
    have twice the orbital dimension (spin-up and spin-down blocks). For every crystal
    the four spin blocks are Fourier-summed into H(k) independently, reduced to the
    orbitals enabled in ``self.basis_def`` and reassembled into a 2x2 block matrix; the
    spin-free overlap S(k) is expanded with ``kron(I_2, S(k))``. The generalized
    eigenvalue problem is then solved through a Cholesky factorization of the overlap.

    Args:
        real_onsite (torch.Tensor): Real part of the on-site SOC Hamiltonian; reshaped
            internally to (total_atoms, 2*nao_max, 2*nao_max).
        imag_onsite (torch.Tensor): Imaginary part of the on-site SOC Hamiltonian; reshaped
            internally to (total_atoms, 2*nao_max, 2*nao_max).
        real_offsite (torch.Tensor): Real part of the off-site SOC Hamiltonian; reshaped
            internally to (total_edges, 2*nao_max, 2*nao_max).
        imag_offsite (torch.Tensor): Imaginary part of the off-site SOC Hamiltonian; reshaped
            internally to (total_edges, 2*nao_max, 2*nao_max).
        crystal_data (DataObject): Batch object; the following attributes are read:
            - edge_index: Indices of connected atom pairs
            - cell: Unit cell vectors (only its leading dimension, the batch size, is used)
            - z: Atomic numbers (must be < 99, the size of the lookup tables)
            - node_counts: Number of atoms in each crystal
            - batch: Batch assignment for each atom
            - k_vecs: k-points for each crystal
            - cell_shift: Cell shift vectors, used to group edges by periodic image
            - nbr_shift: Neighbor cell shifts entering the phase factors
            - Son/Soff: Reference overlap blocks (without SOC)
            - Hon: Only its dtype and device are used (for the 2x2 spin identity)

    Returns:
        tuple: (
            band_energies (torch.Tensor): Band energies concatenated over crystals, with
                bands on the first axis and k-points on the second,
            wavefunctions (torch.Tensor): 1-D tensor holding the flattened eigenvector
                arrays of all crystals, each laid out as (k, orbital, band)
        )
    """
    source_indices, _ = crystal_data.edge_index
    lattice_vectors = crystal_data.cell  # shape: (batch_size, 3, 3)
    batch_size = lattice_vectors.shape[0]
    nao = self.nao_max

    # Reshape Hamiltonian components to matrix form
    real_onsite = real_onsite.reshape(-1, 2 * nao, 2 * nao)
    imag_onsite = imag_onsite.reshape(-1, 2 * nao, 2 * nao)
    real_offsite = real_offsite.reshape(-1, 2 * nao, 2 * nao)
    imag_offsite = imag_offsite.reshape(-1, 2 * nao, 2 * nao)

    # Create orbital validity mask based on atomic numbers
    atomic_orbital_mask = torch.zeros((99, nao)).type_as(crystal_data.z)
    for atomic_number, orbital_indices in self.basis_def.items():
        atomic_orbital_mask[atomic_number][orbital_indices] = 1

    # Get orbital mask for each atom
    orbital_mask = atomic_orbital_mask[crystal_data.z]  # shape: [total_atoms, nao_max]
    orbital_mask_by_crystal = torch.split(orbital_mask, crystal_data.node_counts.tolist(), dim=0)

    # Outer product of the flattened per-crystal mask with itself: entry (a, b) is
    # non-zero only when both orbital a and orbital b exist
    orbital_interaction_masks = []
    for idx in range(batch_size):
        orbital_interaction_masks.append(
            orbital_mask_by_crystal[idx].reshape(-1, 1) *
            orbital_mask_by_crystal[idx].reshape(1, -1)
        )

    # Set the number of valence electrons per atom type
    valence_electron_counts = torch.zeros((99,)).type_as(crystal_data.z)
    for atomic_number, count in self.num_valence.items():
        valence_electron_counts[atomic_number] = count

    # Get valence electron count for each atom and then for each crystal
    atom_valence_counts = valence_electron_counts[crystal_data.z]  # shape: [total_atoms]
    crystal_valence_counts = scatter(atom_valence_counts, crystal_data.batch, dim=0)  # shape: [batch_size]

    # Per-crystal band counts are only pre-computed for the dict form of band_num_control
    if isinstance(self.band_num_control, dict):
        bands_per_atom_type = torch.zeros((99,)).type_as(crystal_data.z)
        for atomic_number, band_count in self.band_num_control.items():
            bands_per_atom_type[atomic_number] = band_count
        atom_band_counts = bands_per_atom_type[crystal_data.z]  # shape: [total_atoms]
        crystal_band_counts = scatter(atom_band_counts, crystal_data.batch, dim=0)  # shape: [batch_size]

    # Split data by crystal
    atoms_per_crystal = crystal_data.node_counts
    atom_counts_list = atoms_per_crystal.tolist()

    # Split Hamiltonians by crystal
    real_onsite_by_crystal = torch.split(real_onsite, atom_counts_list, dim=0)
    imag_onsite_by_crystal = torch.split(imag_onsite, atom_counts_list, dim=0)
    onsite_overlaps_by_crystal = torch.split(
        crystal_data.Son.reshape(-1, nao, nao),
        atom_counts_list,
        dim=0
    )

    # Count edges per crystal (an edge belongs to the crystal of its source atom)
    edge_counter = torch.ones_like(source_indices)
    edges_per_crystal = scatter(edge_counter, crystal_data.batch[source_indices], dim=0)
    edge_counts_list = edges_per_crystal.tolist()
    real_offsite_by_crystal = torch.split(real_offsite, edge_counts_list, dim=0)
    imag_offsite_by_crystal = torch.split(imag_offsite, edge_counts_list, dim=0)
    offsite_overlaps_by_crystal = torch.split(
        crystal_data.Soff.reshape(-1, nao, nao),
        edge_counts_list,
        dim=0
    )

    # Split cell shifts and neighbor shifts by crystal
    cell_shifts_by_crystal = torch.split(crystal_data.cell_shift, edge_counts_list, dim=0)
    nbr_shifts_by_crystal = torch.split(crystal_data.nbr_shift, edge_counts_list, dim=0)

    # Split edge indices by crystal and convert to indices local to each crystal
    edge_indices_by_crystal = torch.split(crystal_data.edge_index, edge_counts_list, dim=1)
    atom_index_offsets = torch.cumsum(atoms_per_crystal, dim=0) - atoms_per_crystal
    local_edge_indices = [
        edge_indices - atom_index_offsets[idx]
        for idx, edge_indices in enumerate(edge_indices_by_crystal)
    ]

    # Spin-up / spin-down halves of the doubled orbital axis
    spin_slices = (slice(None, nao), slice(nao, None))

    # Initialize results containers
    band_energies = []
    wavefunctions = []

    # Process each crystal
    for crystal_idx in range(batch_size):
        k_points = crystal_data.k_vecs[crystal_idx]
        num_atoms = atoms_per_crystal[crystal_idx].item()
        valid_orbital_pairs = orbital_interaction_masks[crystal_idx] > 0

        # Map every edge to the index of its (unique) cell shift. The shifts are moved
        # to the host in a single transfer instead of one transfer per edge.
        cell_shift_tuples = [
            tuple(shift) for shift in cell_shifts_by_crystal[crystal_idx].cpu().tolist()
        ]
        unique_cell_shifts = list(set(cell_shift_tuples))
        cell_to_index = {shift: idx for idx, shift in enumerate(unique_cell_shifts)}
        cell_indices = torch.tensor(
            [cell_to_index[shift] for shift in cell_shift_tuples],
            device=source_indices.device
        ).type_as(source_indices)
        num_cells = len(unique_cell_shifts)

        # Calculate phase factors exp(2*pi*i * shift . k)
        # Shape: (num_k_points, num_cells)
        phase_factors = torch.view_as_complex(
            torch.zeros((self.num_k, num_cells, 2)).type_as(crystal_data.Son)
        )
        # Edges sharing a cell index write to the same column
        phase_factors[:, cell_indices] = torch.exp(
            2j * torch.pi * torch.sum(
                nbr_shifts_by_crystal[crystal_idx][None, :, :] * k_points[:, None, :],
                dim=-1
            )
        )

        # Create atom index tensor
        atom_indices = torch.arange(num_atoms, device=source_indices.device).type_as(source_indices)
        source_atoms, target_atoms = local_edge_indices[crystal_idx]

        # Real-space overlap grouped by cell: (cell, atom_i, atom_j, orbital_a, orbital_b)
        s_by_cell = torch.view_as_complex(
            torch.zeros((
                num_cells, num_atoms, num_atoms, nao, nao, 2
            )).type_as(crystal_data.Son)
        )
        s_by_cell[cell_indices, source_atoms, target_atoms, :, :] += offsite_overlaps_by_crystal[crystal_idx]

        # Calculate k-space overlap matrix by summing over cells with phase factors
        overlap_k = torch.einsum('ijklm,ni->njklm', s_by_cell, phase_factors)
        # Add on-site overlap elements
        overlap_k[:, atom_indices, atom_indices, :, :] += onsite_overlaps_by_crystal[crystal_idx][None, atom_indices, :, :]

        # Reorder to (k, atom_i, orbital_a, atom_j, orbital_b) and merge each
        # (atom, orbital) pair into a single matrix index
        overlap_k = torch.swapaxes(overlap_k, 2, 3)
        overlap_k = overlap_k.reshape(self.num_k, num_atoms * nao, num_atoms * nao)

        # Apply orbital mask to select only valid orbitals
        overlap_k = overlap_k[:, valid_orbital_pairs]
        # The surviving entries form a square matrix per k-point
        num_orbitals = int(math.sqrt(overlap_k.numel() / self.num_k))
        overlap_k = overlap_k.reshape(self.num_k, num_orbitals, num_orbitals)

        # Expand overlap matrix to account for spin: block-diagonal [[S, 0], [0, S]]
        identity = torch.eye(2, device=crystal_data.Hon.device).type_as(crystal_data.Hon)
        overlap_k_soc = torch.kron(identity, overlap_k)

        # Combine real and imaginary parts, then cut out the four spin blocks in the
        # order (up,up), (up,down), (down,up), (down,down)
        onsite_complex = real_onsite_by_crystal[crystal_idx] + 1.0j * imag_onsite_by_crystal[crystal_idx]
        offsite_complex = real_offsite_by_crystal[crystal_idx] + 1.0j * imag_offsite_by_crystal[crystal_idx]
        onsite_soc_blocks = [
            onsite_complex[:, rows, cols] for rows in spin_slices for cols in spin_slices
        ]
        offsite_soc_blocks = [
            offsite_complex[:, rows, cols] for rows in spin_slices for cols in spin_slices
        ]

        # Build H(k) for each spin block
        hamiltonian_k_blocks = []
        for onsite_block, offsite_block in zip(onsite_soc_blocks, offsite_soc_blocks):
            h_by_cell = torch.view_as_complex(
                torch.zeros((
                    num_cells, num_atoms, num_atoms, nao, nao, 2
                )).type_as(crystal_data.Son)
            )
            # Set off-site Hamiltonian elements
            h_by_cell[cell_indices, source_atoms, target_atoms, :, :] += offsite_block
            # Calculate k-space Hamiltonian matrix
            hamiltonian_k_block = torch.einsum('ijklm,ni->njklm', h_by_cell, phase_factors)
            # Add on-site Hamiltonian elements
            hamiltonian_k_block[:, atom_indices, atom_indices, :, :] += onsite_block[None, atom_indices, :, :]
            # Same reordering and masking as for the overlap
            hamiltonian_k_block = torch.swapaxes(hamiltonian_k_block, 2, 3)
            hamiltonian_k_block = hamiltonian_k_block.reshape(self.num_k, num_atoms * nao, num_atoms * nao)
            hamiltonian_k_block = hamiltonian_k_block[:, valid_orbital_pairs]
            hamiltonian_k_block = hamiltonian_k_block.reshape(self.num_k, num_orbitals, num_orbitals)
            hamiltonian_k_blocks.append(hamiltonian_k_block)

        # Combine spin blocks into full SOC Hamiltonian
        hamiltonian_k_soc = torch.cat([
            torch.cat([hamiltonian_k_blocks[0], hamiltonian_k_blocks[1]], dim=-1),
            torch.cat([hamiltonian_k_blocks[2], hamiltonian_k_blocks[3]], dim=-1)
        ], dim=-2)

        # Solve generalized eigenvalue problem using Cholesky decomposition S = L L^H
        cholesky_factor = torch.linalg.cholesky(overlap_k_soc)
        cholesky_factor_conj_transpose = torch.transpose(cholesky_factor.conj(), dim0=-1, dim1=-2)
        cholesky_inverse = torch.linalg.inv(cholesky_factor)
        cholesky_conj_transpose_inverse = torch.linalg.inv(cholesky_factor_conj_transpose)

        # Transform to standard eigenvalue problem: L^-1 H L^-H
        transformed_hamiltonian = torch.bmm(
            torch.bmm(cholesky_inverse, hamiltonian_k_soc),
            cholesky_conj_transpose_inverse
        )

        # Solve eigenvalue problem
        eigenvalues, eigenvectors = torch.linalg.eigh(transformed_hamiltonian)

        # Transform eigenvectors back to original basis. eigh returns eigenvectors as
        # columns, so the layout here is (k, orbital, band).
        eigenvectors = torch.bmm(cholesky_conj_transpose_inverse, eigenvectors)

        # Apply band number control if specified. Bands live on the LAST axis of
        # ``eigenvectors``; slicing axis 1 would drop orbital components instead of
        # selecting the bands that match the selected eigenvalues.
        if self.band_num_control is not None:
            if isinstance(self.band_num_control, dict):
                # Use pre-calculated band counts: keep the lowest bands
                num_bands = int(crystal_band_counts[crystal_idx])
                eigenvalues = eigenvalues[:, :num_bands]
                eigenvectors = eigenvectors[:, :, :num_bands]
            else:
                # Band window centered around the valence electron count. The start is
                # clamped at 0 because a negative start would wrap around in slicing.
                valence_count = int(crystal_valence_counts[crystal_idx])
                band_window_start = max(0, valence_count - self.band_num_control)
                band_window_end = valence_count + self.band_num_control
                eigenvalues = eigenvalues[:, band_window_start:band_window_end]
                eigenvectors = eigenvectors[:, :, band_window_start:band_window_end]

        # Store results; eigenvalues are transposed to (band, k). Wavefunctions are
        # flattened per crystal so crystals of different size can share a batch.
        band_energies.append(torch.transpose(eigenvalues, dim0=-1, dim1=-2))
        wavefunctions.append(eigenvectors.reshape(-1))

    # Combine results across all crystals
    band_energies = torch.cat(band_energies, dim=0)
    wavefunctions = torch.cat(wavefunctions, dim=0)
    return band_energies, wavefunctions
[docs]
def apply_orbital_masks_to_hamiltonians(self, onsite_hamiltonian, offsite_hamiltonian, data, return_masks=False):
    """
    Zero out Hamiltonian entries that involve orbitals missing from an atom's basis.

    A lookup table built from ``self.basis_def`` marks, for every atomic number, which of
    the ``nao_max`` orbital slots are populated. An on-site entry (i, j) is kept only when
    the atom has both orbitals i and j; an off-site entry (i, j) is kept only when the
    source atom has orbital i and the target atom has orbital j. All other entries are
    multiplied by zero.

    Args:
        onsite_hamiltonian (torch.Tensor): On-site blocks, one per atom, in any shape
            holding nao_max * nao_max values per atom (flat or square).
        offsite_hamiltonian (torch.Tensor): Off-site blocks, one per edge, in any shape
            holding nao_max * nao_max values per edge (flat or square).
        data (DataObject): Object providing:
            - z: Atomic numbers for each atom (must be < 99)
            - edge_index: (source, target) atom indices of every edge
        return_masks (bool, optional): Also return the two masks. Defaults to False.

    Returns:
        tuple: ``(masked_onsite, masked_offsite)``, or
            ``(masked_onsite, masked_offsite, onsite_mask, offsite_mask)`` when
            ``return_masks`` is True. Masks have the shape of the corresponding input
            and the dtype of ``data.z``.
    """
    nao = self.nao_max
    species = data.z

    # Row Z of the table carries a 1 in every orbital slot listed for element Z
    basis_table = torch.zeros((99, nao), device=species.device).type_as(species)
    for element, active_orbitals in self.basis_def.items():
        basis_table[element][active_orbitals] = 1

    def pair_mask(row_flags, col_flags, target_shape):
        # Batched outer product: entry (i, j) is 1 only if row orbital i and
        # column orbital j both exist
        return (row_flags[:, :, None] * col_flags[:, None, :]).reshape(target_shape)

    # On-site: an atom paired with itself
    per_atom_flags = basis_table[species].view(-1, nao)  # [n_atoms, nao_max]
    onsite_orbital_mask = pair_mask(per_atom_flags, per_atom_flags, onsite_hamiltonian.shape)
    masked_onsite = onsite_hamiltonian * onsite_orbital_mask

    # Off-site: source atom orbitals index rows, target atom orbitals index columns
    edge_sources, edge_targets = data.edge_index
    source_flags = basis_table[species[edge_sources]].view(-1, nao)  # [n_edges, nao_max]
    target_flags = basis_table[species[edge_targets]].view(-1, nao)  # [n_edges, nao_max]
    offsite_orbital_mask = pair_mask(source_flags, target_flags, offsite_hamiltonian.shape)
    masked_offsite = offsite_hamiltonian * offsite_orbital_mask

    if not return_masks:
        return masked_onsite, masked_offsite
    return masked_onsite, masked_offsite, onsite_orbital_mask, offsite_orbital_mask
[docs]
def symmetrize_orbital_coefficients(self, coefficient_matrix):
    """
    Average orbital coefficient matrices within each angular momentum block.

    Every block of rows belonging to one angular momentum shell is replaced by the mean
    row of that block, then the same is done for the corresponding blocks of columns, so
    that all m components of a shell carry identical coefficients. The input tensor is
    left untouched; the result is returned as a new tensor.

    Index ranges that are averaged (labels follow the file's orbital ordering):
        - 3:6   p orbitals  (3 components, nao_max >= 14)
        - 6:9   p' orbitals (3 components, nao_max >= 14)
        - 9:14  d orbitals  (5 components, nao_max >= 14)
        - 14:19 d' orbitals (5 components, nao_max >= 19)
        - 19:26 f orbitals  (7 components, nao_max == 26)
    Indices 0:3 are never averaged, and for nao_max < 14 nothing is averaged at all.

    Args:
        coefficient_matrix (torch.Tensor): Floating-point coefficient matrix with shape
            (batch_size, nao_max**2) or (batch_size, nao_max, nao_max).

    Returns:
        torch.Tensor: Symmetrized coefficients with shape (batch_size, nao_max**2).
    """
    batch_size = coefficient_matrix.shape[0]

    # reshape() returns a view whenever it can, so the in-place block assignments
    # below would otherwise overwrite the caller's tensor; work on a private copy.
    matrix_3d = coefficient_matrix.reshape(batch_size, self.nao_max, self.nao_max).clone()

    # Define orbital blocks based on angular momentum
    orbital_blocks = []
    if self.nao_max >= 14:
        orbital_blocks = [
            (3, 6),   # p orbitals (3 components)
            (6, 9),   # p' orbitals (3 components)
            (9, 14),  # d orbitals (5 components)
        ]
    if self.nao_max >= 19:
        orbital_blocks.append((14, 19))  # d' orbitals (5 components)
    if self.nao_max == 26:
        orbital_blocks.append((19, 26))  # f orbitals (7 components)

    # Row-wise averaging: the (batch, 1, nao_max) mean broadcasts over the block's rows
    for start_idx, end_idx in orbital_blocks:
        block_mean = torch.mean(matrix_3d[:, start_idx:end_idx], dim=1, keepdim=True)
        matrix_3d[:, start_idx:end_idx] = block_mean

    # Column-wise averaging, applied to the already row-averaged matrix
    for start_idx, end_idx in orbital_blocks:
        block_mean = torch.mean(matrix_3d[:, :, start_idx:end_idx], dim=2, keepdim=True)
        matrix_3d[:, :, start_idx:end_idx] = block_mean

    # Flatten the two orbital axes back into one
    return matrix_3d.reshape(batch_size, -1)
[docs]
def create_cell_index_mapping(self, unique_cell_vectors: List[List[int]]) -> dict:
    """
    Build a lookup table from cell vectors to their positions in the input list.

    Every vector is converted to a tuple so that it can serve as a dictionary key;
    the value stored under that key is the vector's position in
    ``unique_cell_vectors``. If the same vector occurs more than once, the position
    of its last occurrence is the one that is kept.

    Args:
        unique_cell_vectors (List[List[int]]): Cell vectors, each given as a
            sequence of integers.

    Returns:
        dict: Mapping from each cell vector (as a tuple) to its integer index.
    """
    index_by_cell = {}
    for position, cell_vector in enumerate(unique_cell_vectors):
        index_by_cell[tuple(cell_vector)] = position
    return index_by_cell
[docs]
def build_edge_lookup_structures(self, data, inverse_edge_indices=None):
    """
    Build per-atom edge lookup tables for a crystal graph.

    Two structures are produced:
    1. For every atom, the indices of all edges whose source node is that atom.
    2. For every atom and every cell-shift index, the inverse edges of that atom's
       outgoing edges, grouped by the cell-shift index of the inverse edge.
       Assuming ``inverse_edge_indices`` maps each edge i->j to its reverse j->i,
       these are the edges on which the atom is the target.

    Args:
        data (DataObject): Crystal graph data. The attributes read here are:
            - edge_index: indexable so that ``edge_index[0]`` gives the source
              node of every edge
            - unique_cell_shift: only its length (number of cell shifts) is used
            - cell_shift_indices: tensor giving the cell-shift index of each edge
            - z: only its length (number of atoms) is used
        inverse_edge_indices (torch.Tensor): Tensor giving, for each edge, the
            index of its inverse edge (i->j maps to j->i). Although the parameter
            has a ``None`` default, it is required.

    Returns:
        tuple:
            - source_edge_indices (list of torch.Tensor): For each atom, the indices
              of the edges where that atom is the source.
            - target_edge_indices (list of list of torch.Tensor): For each atom and
              each cell-shift index, a 1-D tensor (possibly empty) of inverse-edge
              indices, cast to the type of the source-node tensor. The order within
              each tensor follows the order of the atom's outgoing edges.

    Raises:
        TypeError: If ``inverse_edge_indices`` is not provided.
    """
    # Fail early with a message that names the argument instead of the opaque
    # "'NoneType' object is not subscriptable" that indexing None would produce.
    if inverse_edge_indices is None:
        raise TypeError(
            "build_edge_lookup_structures() requires 'inverse_edge_indices' "
            "to build the target edge lookup"
        )

    source_nodes = data.edge_index[0]
    cell_shift_indices = data.cell_shift_indices
    num_atoms = len(data.z)
    num_cell_shifts = len(data.unique_cell_shift)

    # Edges that originate from each atom
    source_edge_indices = [
        torch.where(source_nodes == atom_idx)[0] for atom_idx in range(num_atoms)
    ]

    target_edge_indices = []
    for outgoing_edges in source_edge_indices:
        # Reverse partners of this atom's outgoing edges
        inverse_edges = inverse_edge_indices[outgoing_edges]
        # Cell-shift index carried by each of those reverse edges
        shifts_of_inverse_edges = cell_shift_indices[inverse_edges]
        # Boolean masking keeps the original edge order and yields an empty tensor
        # for shifts without edges, without a per-edge host synchronisation.
        target_edge_indices.append([
            inverse_edges[shifts_of_inverse_edges == shift_idx].type_as(source_nodes)
            for shift_idx in range(num_cell_shifts)
        ])

    return source_edge_indices, target_edge_indices
[docs]
def create_orbital_validity_mask(self, atomic_numbers):
    """
    Build a per-element table flagging which atomic orbitals are valid.

    The table has one row per atomic number and one column per orbital slot. An
    entry is 1 when ``self.basis_def`` lists that orbital index for that atomic
    number and 0 otherwise. Atomic numbers absent from ``self.basis_def`` keep an
    all-zero row.

    Args:
        atomic_numbers (torch.Tensor): Used only as a template for the dtype and
            device of the returned table; its values are not read.

    Returns:
        torch.Tensor: Tensor of shape (99, nao_max) with the dtype and device of
            ``atomic_numbers``. With 99 rows, atomic numbers 0..98 are indexable.
    """
    # Start from "no orbital is valid" and switch on the listed ones per element
    validity_table = torch.zeros(
        (99, self.nao_max),
        dtype=atomic_numbers.dtype,
        device=atomic_numbers.device,
    )
    for element_z in self.basis_def:
        element_row = validity_table[element_z]  # view into the table
        element_row[self.basis_def[element_z]] = 1
    return validity_table
[docs]
def build_interaction_masks(self, data):
    """
    Build boolean masks marking valid orbital pairs in Hamiltonian blocks.

    For every atom (on-site block) and every edge (off-site block) the mask entry
    for orbital pair (j, k) is True when orbital j is valid for the first atom and
    orbital k is valid for the second one, according to the table returned by
    ``create_orbital_validity_mask``. For on-site blocks both atoms are the same;
    for off-site blocks they are the edge's source and target atoms.

    Args:
        data (DataObject): Graph data providing:
            - edge_index: source and target atom index of every edge
            - z: atomic number of every atom

    Returns:
        torch.Tensor: Boolean tensor of shape (n_atoms + n_edges, nao_max**2).
            The first n_atoms rows are the flattened on-site masks, the remaining
            n_edges rows are the flattened off-site masks.
    """
    edge_sources, edge_targets = data.edge_index
    species = data.z
    validity_table = self.create_orbital_validity_mask(species)
    flat_block_size = self.nao_max**2

    def pair_mask(row_valid, col_valid):
        # Outer product per sample: (n, nao) x (n, nao) -> (n, nao, nao)
        return (row_valid.unsqueeze(2) * col_valid.unsqueeze(1)).bool()

    valid_per_atom = validity_table[species]
    onsite_masks = pair_mask(valid_per_atom, valid_per_atom)
    offsite_masks = pair_mask(
        validity_table[species[edge_sources]],
        validity_table[species[edge_targets]],
    )

    # Flatten each block and put on-site rows ahead of off-site rows
    return torch.cat(
        [
            onsite_masks.reshape(-1, flat_block_size),
            offsite_masks.reshape(-1, flat_block_size),
        ],
        dim=0,
    )
[docs]
def build_column_wise_interaction_masks(self, data):
    """
    Build orbital-pair masks with a duplicated middle dimension of size 2.

    The per-atom (on-site) and per-edge (off-site) masks are computed exactly as in
    ``build_interaction_masks``; each mask is then stacked with itself along a new
    dimension 1, so both slices along that dimension are identical.

    Args:
        data (DataObject): Graph data providing:
            - edge_index: source and target atom index of every edge
            - z: atomic number of every atom

    Returns:
        torch.Tensor: Boolean tensor of shape (n_atoms + n_edges, 2, nao_max**2),
            with on-site rows first and off-site rows after them.
    """
    edge_sources, edge_targets = data.edge_index
    species = data.z
    validity_table = self.create_orbital_validity_mask(species)
    flat_block_size = self.nao_max**2

    def doubled_pair_mask(row_valid, col_valid):
        # (n, nao) x (n, nao) -> (n, nao, nao) boolean outer product
        pair = (row_valid.unsqueeze(2) * col_valid.unsqueeze(1)).bool()
        # Duplicate along a new axis 1, then flatten the orbital-pair axes
        return torch.stack((pair, pair), dim=1).reshape(-1, 2, flat_block_size)

    valid_per_atom = validity_table[species]
    onsite_part = doubled_pair_mask(valid_per_atom, valid_per_atom)
    offsite_part = doubled_pair_mask(
        validity_table[species[edge_sources]],
        validity_table[species[edge_targets]],
    )
    return torch.cat([onsite_part, offsite_part], dim=0)
[docs]
def build_spin_orbit_interaction_masks(self, data):
    """
    Build orbital-pair masks for the spin-doubled (2*nao_max) basis.

    The per-atom and per-edge orbital-pair validity products are computed first.
    Each one is then passed four times to ``blockwise_2x2_concat``, so every spin
    block receives the same base mask (presumably arranged as a 2x2 block matrix;
    confirm against that helper). The result is flattened to ``(2*nao_max)**2``
    columns and converted to bool. On-site and off-site masks are merged with
    ``concatenate_hamiltonians_by_crystal``.

    Args:
        data (DataObject): Graph data providing:
            - edge_index: source and target atom index of every edge
            - z: atomic number of every atom
            It is also forwarded to ``concatenate_hamiltonians_by_crystal``.

    Returns:
        tuple: A tuple containing:
            - real_imag_masks (torch.Tensor): Output of
              ``concatenate_hamiltonians_by_crystal`` for the expanded masks;
              expected shape (n_atoms + n_edges, (2*nao_max)**2) -- TODO confirm
              against that method.
            - combined_masks (torch.Tensor): ``real_imag_masks`` concatenated with
              itself along dimension 0 (twice as many rows).
    """
    edge_sources, edge_targets = data.edge_index
    species = data.z
    validity_table = self.create_orbital_validity_mask(species)
    spin_block_size = (2 * self.nao_max) ** 2

    def spin_expanded_mask(row_valid, col_valid):
        # Per-sample outer product, kept numeric until after the block expansion
        base = torch.einsum('ni, nj -> nij', row_valid, col_valid)
        # Same base mask in all four spin blocks, then flatten and binarise
        expanded = blockwise_2x2_concat(base, base, base, base)
        return expanded.reshape(-1, spin_block_size).bool()

    valid_per_atom = validity_table[species]
    onsite_spin_masks = spin_expanded_mask(valid_per_atom, valid_per_atom)
    offsite_spin_masks = spin_expanded_mask(
        validity_table[species[edge_sources]],
        validity_table[species[edge_targets]],
    )

    real_imag_masks = self.concatenate_hamiltonians_by_crystal(
        data, onsite_spin_masks, offsite_spin_masks
    )
    # Two copies back to back along the row dimension
    combined_masks = torch.cat([real_imag_masks, real_imag_masks], dim=0)
    return real_imag_masks, combined_masks
[docs]
def calculate_sparsity_ratio(self, data):
    """
    Return the ratio of padded matrix elements to effective matrix elements.

    Every on-site and off-site block is counted as ``nao_max**2`` padded elements.
    The effective count of a block is the product of the basis sizes (number of
    entries in ``self.basis_def``) of the atoms involved.

    Parameters
    ----------
    data : object
        Data object with atomic numbers in ``z``. On-site blocks are counted only
        if ``data`` has an attribute ``Hon``; off-site blocks only if it has both
        ``Hoff`` and ``edge_index``. If ``data`` has an attribute
        ``sparsity_ratio``, that value is returned directly.

    Returns
    -------
    torch.Tensor
        Scalar float32 tensor on the device of ``data.z`` (a precomputed tensor
        value is only moved/cast, not reshaped). ``inf`` when the effective count
        is zero.

    Notes
    -----
    An atom whose atomic number is missing from ``self.basis_def`` contributes
    ``nao_max`` orbitals to on-site blocks, and any edge touching such an atom is
    counted as a full ``nao_max**2`` block. Lookup tables are cached on ``self``
    per device and are not rebuilt if ``basis_def`` changes afterwards.
    """
    device = data.z.device

    # Shortcut: a value precomputed offline and attached to the data object
    if hasattr(data, "sparsity_ratio"):
        precomputed = data.sparsity_ratio
        if not torch.is_tensor(precomputed):
            return torch.tensor(float(precomputed), device=device, dtype=torch.float32)
        return precomputed.to(device=device, dtype=torch.float32)

    # Z-indexed lookup tables, rebuilt only when missing or on another device
    tables = getattr(self, "_sparsity_ratio_lookup_tables", None)
    if tables is None or tables[0].device != device:
        table_size = 256  # fixed size; larger than any atomic number
        orbital_counts = torch.full(
            (table_size,), int(self.nao_max), dtype=torch.long, device=device
        )
        has_basis = torch.zeros((table_size,), dtype=torch.bool, device=device)
        for raw_z, basis_indices in self.basis_def.items():
            slot = int(raw_z)
            if slot < 0 or slot >= table_size:
                continue
            orbital_counts[slot] = int(len(basis_indices))
            has_basis[slot] = True
        tables = (orbital_counts, has_basis)
        self._sparsity_ratio_lookup_tables = tables
    orbital_counts, has_basis = tables

    species = data.z.to(torch.long)
    padded_block_size = int(self.nao_max) ** 2
    padded_total = species.new_zeros((), dtype=torch.long)
    effective_total = species.new_zeros((), dtype=torch.long)

    # On-site: one block per atom, effective size n_i * n_i
    if hasattr(data, "Hon"):
        padded_total = padded_total + int(int(species.numel()) * padded_block_size)
        per_atom = orbital_counts[species]
        effective_total = effective_total + (per_atom * per_atom).sum()

    # Off-site: one block per edge, effective size n_src * n_dst when both end
    # atoms have a basis definition, otherwise the full padded block
    if hasattr(data, "Hoff") and hasattr(data, "edge_index"):
        edges = data.edge_index
        padded_total = padded_total + int(int(edges.shape[1]) * padded_block_size)
        z_source = species[edges[0].to(torch.long)]
        z_target = species[edges[1].to(torch.long)]
        count_source = orbital_counts[z_source]
        count_target = orbital_counts[z_target]
        per_edge = torch.where(
            has_basis[z_source] & has_basis[z_target],
            count_source * count_target,
            count_source.new_full(count_source.shape, padded_block_size),
        )
        effective_total = effective_total + per_edge.sum()

    # Divide in float64; guard the zero-denominator case with +inf
    numerator = padded_total.to(torch.float64)
    denominator = effective_total.to(torch.float64)
    ratio = torch.where(
        denominator > 0,
        numerator / denominator,
        torch.full_like(numerator, float("inf")),
    )
    return ratio.to(torch.float32).detach()
[docs]
def validate_elements_in_basis_def(self, data, raise_error=True):
    """
    Check that every element present in ``data`` has an entry in ``self.basis_def``.

    Parameters
    ----------
    data : object
        Object exposing the atomic numbers as a tensor attribute ``z`` (checked
        first) or ``Z``.
    raise_error : bool, optional
        If True (default), raise ``ValueError`` when at least one element is
        missing. If False, never raise for missing elements and report them in
        the return value instead.

    Returns
    -------
    bool or tuple
        With ``raise_error=True``: ``True`` (only reached when nothing is missing).
        With ``raise_error=False``: the tuple
        ``(all_elements_valid, missing_atomic_numbers)``, where the second item is
        the list of unique atomic numbers absent from ``self.basis_def``.

    Raises
    ------
    AttributeError
        If ``data`` has neither a ``z`` nor a ``Z`` attribute.
    ValueError
        If elements are missing and ``raise_error`` is True. The message lists
        each missing element as ``"<symbol> (Z=<n>)"`` when the optional
        ``qcelemental`` package can resolve the symbol, otherwise as ``"Z=<n>"``.
    """
    # Locate the atomic numbers on the data object
    if hasattr(data, 'z'):
        atomic_numbers = data.z
    elif hasattr(data, 'Z'):
        atomic_numbers = data.Z
    else:
        raise AttributeError("Input data object has no 'z' or 'Z' attribute containing atomic numbers")

    # Unique atomic numbers as plain Python values
    unique_atomic_numbers = atomic_numbers.unique().cpu().tolist()
    missing_atomic_numbers = [num for num in unique_atomic_numbers if num not in self.basis_def]
    all_elements_valid = not missing_atomic_numbers

    if missing_atomic_numbers and raise_error:
        # qcelemental is optional; without it only atomic numbers are shown
        try:
            from qcelemental import periodictable
        except ImportError:
            periodictable = None

        def describe(num):
            if periodictable is None:
                return f"Z={num}"
            try:
                return f"{periodictable.to_symbol(num)} (Z={num})"
            except Exception:
                # Best-effort label only: a failed symbol lookup (e.g. an atomic
                # number the table does not know) must not mask the ValueError
                # raised below, so fall back to the bare atomic number.
                return f"Z={num}"

        element_symbols = [describe(num) for num in missing_atomic_numbers]
        raise ValueError(f"The following elements are missing from basis_def: {', '.join(element_symbols)}")

    # The return shape depends on raise_error (kept for backward compatibility)
    return all_elements_valid if raise_error else (all_elements_valid, missing_atomic_numbers)
[docs]
def forward(self, data, graph_representation: dict = None):
"""
Forward pass of the Hamiltonian prediction model.
This method constructs Hamiltonian matrices for crystal structures from graph representations,
with support for various physical effects including spin-orbit coupling (SOC), magnetism,
and long-range interactions. It can also calculate electronic band structures.
The method handles several different physical regimes:
1. Non-magnetic systems (standard DFT-like Hamiltonians)
2. Collinear spin-polarized systems (separate Hamiltonians for up/down spins)
3. Non-collinear magnetic systems (full 2x2 spin blocks with SOC)
4. Systems with spin-orbit coupling (complex Hamiltonians with real/imaginary parts)
Args:
data (DataObject): Contains crystal structure information including:
- edge_index: Connectivity information between atoms
- z: Atomic numbers
- pos: Atomic positions
- cell: Unit cell vectors
- Hon/Hoff: Reference on-site/off-site Hamiltonian matrices (if available)
- Son/Soff: Reference on-site/off-site overlap matrices
- node_counts: Number of atoms in each crystal
- batch: Batch assignment for each atom
- inv_edge_idx: Inverse edge indices
graph_representation (dict): Output from the graph neural network containing:
- node_attr: Node feature vectors from the GNN
- edge_attr: Edge feature vectors from the GNN
Returns:
dict: Dictionary containing predicted quantities, which may include:
- hamiltonian: Predicted Hamiltonian matrix
- overlap: Predicted overlap matrix (if ham_only=False)
- band_energy: Electronic band energies (if calculate_band_energy=True)
- wavefunction: Eigenvectors of the Hamiltonian (if calculate_band_energy=True)
- band_gap: Band gap values (if calculate_band_energy=True)
- Various additional matrices for debugging or analysis
Note:
The exact contents of the return dictionary depend on the model configuration
parameters (soc_switch, spin_constrained, collinear_spin, etc.)
"""
# Validate that all elements in the data are present in basis_def
self.validate_elements_in_basis_def(data)
# Data format compatibility handling
if 'H0_u' in data:
# Handle legacy format from Hongyu yu's data
Hon_u0 = data.H0_u[:len(data.z)].flatten(1)
Hon_d0 = data.H0_d[:len(data.z)].flatten(1)
Hoff_u0 = data.H0_u[len(data.z):].flatten(1)
Hoff_d0 = data.H0_d[len(data.z):].flatten(1)
data.Hon0 = torch.stack([Hon_u0, Hon_d0], dim=1)
data.Hoff0 = torch.stack([Hoff_u0, Hoff_d0], dim=1)
data.Hon = torch.stack([data.H_u[:len(data.z)], data.H_d[:len(data.z)]], dim=1).flatten(2)
data.Hoff = torch.stack([data.H_u[len(data.z):], data.H_d[len(data.z):]], dim=1).flatten(2)
# Prepare combined Hamiltonian and overlap matrices if not present
if 'hamiltonian' not in data:
data.hamiltonian = self.concatenate_hamiltonians_by_crystal(data, data.Hon, data.Hoff)
if 'overlap' not in data:
data.overlap = self.concatenate_hamiltonians_by_crystal(data, data.Son, data.Soff)
# Extract node and edge attributes from graph representation
node_attr = graph_representation['node_attr']
edge_attr = graph_representation['edge_attr'] # mji
source_indices, target_indices = data.edge_index
# Calculate batch-specific inverse edge indices
inverse_edge_indices = data.inv_edge_idx
edge_count = torch.ones_like(source_indices)
edge_count_by_batch = scatter(edge_count, data.batch[source_indices], dim=0)
edge_count_offset = torch.cumsum(edge_count_by_batch, dim=0) - edge_count_by_batch
inverse_edge_indices = inverse_edge_indices + edge_count_offset[data.batch[source_indices]]
# Ensure irreps dimensions have correct data type
self.hamiltonian_irreps_dimensions = self.hamiltonian_irreps_dimensions.type_as(source_indices)
# === OVERLAP MATRIX CALCULATION (if enabled) ===
if not self.ham_only:
# Calculate on-site overlap matrix
node_spherical_harmonics = self.onsite_overlap_network(node_attr)
node_spherical_components = torch.split(node_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
onsite_overlap = self.merge_tensor_components(node_spherical_components) # shape (n_atoms, nao_max2)
# Reorder indices to match physical orbital ordering
onsite_overlap = self.reorder_matrix(onsite_overlap)
# Impose Hermitian symmetry
onsite_overlap = self.symmetrize_onsite_hamiltonian(onsite_overlap)
# Calculate off-site overlap matrix
edge_spherical_harmonics = self.offsite_overlap_network(edge_attr)
edge_spherical_components = torch.split(edge_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
offsite_overlap = self.merge_tensor_components(edge_spherical_components)
# Reorder indices and impose symmetry
offsite_overlap = self.reorder_matrix(offsite_overlap)
offsite_overlap = self.symmetrize_offsite_hamiltonian(offsite_overlap, inverse_edge_indices)
# Apply orbital masking for certain Hamiltonian types
if self.ham_type in ['openmx', 'pasp', 'siesta', 'abacus']:
onsite_overlap, offsite_overlap = self.apply_orbital_masks_to_hamiltonians(onsite_overlap, offsite_overlap, data)
# === HAMILTONIAN CALCULATION WITH SPIN EFFECTS ===
if self.soc_switch or self.spin_constrained:
# Handle spin-orbit coupling
if self.soc_switch:
# Different SOC basis implementations
if self.soc_basis == 'so3':
# SO(3) representation of spin-orbit coupling
if self.add_H_nonsoc:
# Use pre-defined non-SOC Hamiltonians
onsite_hamiltonian, offsite_hamiltonian = data.Hon_nonsoc, data.Hoff_nonsoc
# Process the H0 matrices for SOC
onsite_h0, offsite_h0 = data['Hon0'], data['Hoff0']
onsite_h0_resized = onsite_h0.reshape(-1, 2 * self.nao_max, 2 * self.nao_max)
offsite_h0_resized = offsite_h0.reshape(-1, 2 * self.nao_max, 2 * self.nao_max)
# Create zero blocks for the diagonal blocks
zero_onsite = torch.zeros_like(data['Son']).reshape(-1, self.nao_max, self.nao_max)
zero_offsite = torch.zeros_like(data['Soff']).reshape(-1, self.nao_max, self.nao_max)
# Zero out the diagonal blocks (spin-conserving terms)
onsite_h0_resized[:, :self.nao_max, :self.nao_max] = zero_onsite
onsite_h0_resized[:, self.nao_max:, self.nao_max:] = zero_onsite
offsite_h0_resized[:, :self.nao_max, :self.nao_max] = zero_offsite
offsite_h0_resized[:, self.nao_max:, self.nao_max:] = zero_offsite
# Store processed matrices
data['Hon0'] = onsite_h0_resized.reshape(-1, (2 * self.nao_max) ** 2)
data['Hoff0'] = offsite_h0_resized.reshape(-1, (2 * self.nao_max) ** 2)
else:
# Calculate Hamiltonians from spherical harmonics
node_spherical_harmonics = self.onsite_hamiltonian_network(node_attr)
node_spherical_components = torch.split(node_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
onsite_hamiltonian = self.merge_tensor_components(node_spherical_components)
# Process on-site Hamiltonian
onsite_hamiltonian = self.reorder_matrix(onsite_hamiltonian)
onsite_hamiltonian = self.symmetrize_onsite_hamiltonian(onsite_hamiltonian)
# Calculate off-site Hamiltonian
edge_spherical_harmonics = self.offsite_hamiltonian_network(edge_attr)
edge_spherical_components = torch.split(edge_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
offsite_hamiltonian = self.merge_tensor_components(edge_spherical_components)
# Process off-site Hamiltonian
offsite_hamiltonian = self.reorder_matrix(offsite_hamiltonian)
offsite_hamiltonian = self.symmetrize_offsite_hamiltonian(offsite_hamiltonian, inverse_edge_indices)
# Apply orbital masking
onsite_hamiltonian, offsite_hamiltonian = self.apply_orbital_masks_to_hamiltonians(onsite_hamiltonian, offsite_hamiltonian, data)
# Calculate SOC coupling strength
ksi_onsite = self.onsite_ksi_network(node_attr)
ksi_onsite = self.symmetrize_orbital_coefficients(ksi_onsite)
ksi_offsite = self.offsite_ksi_network(edge_attr)
ksi_offsite = self.symmetrize_orbital_coefficients(ksi_offsite)
# Construct SOC Hamiltonian real part (2x2 block structure)
soc_onsite_real = torch.zeros((onsite_hamiltonian.shape[0], 2*self.nao_max, 2*self.nao_max)).type_as(onsite_hamiltonian)
# Diagonal blocks (spin-conserving terms)
soc_onsite_real[:, :self.nao_max, :self.nao_max] = onsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
soc_onsite_real[:, self.nao_max:, self.nao_max:] = onsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
# Off-diagonal blocks (spin-flip terms)
soc_onsite_real[:, :self.nao_max, self.nao_max:] = self.symmetrize_onsite_hamiltonian(
(ksi_onsite*data.Lon[:, :, 1]), hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_onsite_real[:, self.nao_max:, :self.nao_max] = self.symmetrize_onsite_hamiltonian(
(ksi_onsite*data.Lon[:, :, 1]), hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_onsite_real = soc_onsite_real.reshape(-1, (2*self.nao_max)**2)
# Construct SOC Hamiltonian imaginary part
soc_onsite_imag = torch.zeros((onsite_hamiltonian.shape[0], 2*self.nao_max, 2*self.nao_max)).type_as(onsite_hamiltonian)
# Diagonal blocks
soc_onsite_imag[:, :self.nao_max, :self.nao_max] = self.symmetrize_onsite_hamiltonian(
(ksi_onsite*data.Lon[:, :, 2]), hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_onsite_imag[:, self.nao_max:, self.nao_max:] = -self.symmetrize_onsite_hamiltonian(
(ksi_onsite*data.Lon[:, :, 2]), hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
# Off-diagonal blocks
soc_onsite_imag[:, :self.nao_max, self.nao_max:] = self.symmetrize_onsite_hamiltonian(
(ksi_onsite*data.Lon[:, :, 0]), hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_onsite_imag[:, self.nao_max:, :self.nao_max] = -self.symmetrize_onsite_hamiltonian(
(ksi_onsite*data.Lon[:, :, 0]), hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_onsite_imag = soc_onsite_imag.reshape(-1, (2*self.nao_max)**2)
# Similar construction for off-site SOC Hamiltonian
soc_offsite_real = torch.zeros((offsite_hamiltonian.shape[0], 2*self.nao_max, 2*self.nao_max)).type_as(offsite_hamiltonian)
# Diagonal blocks
soc_offsite_real[:, :self.nao_max, :self.nao_max] = offsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
soc_offsite_real[:, self.nao_max:, self.nao_max:] = offsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
# Off-diagonal blocks
soc_offsite_real[:, :self.nao_max, self.nao_max:] = self.symmetrize_offsite_hamiltonian(
(ksi_offsite*data.Loff[:, :, 1]), inverse_edge_indices, hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_offsite_real[:, self.nao_max:, :self.nao_max] = self.symmetrize_offsite_hamiltonian(
(ksi_offsite*data.Loff[:, :, 1]), inverse_edge_indices, hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_offsite_real = soc_offsite_real.reshape(-1, (2*self.nao_max)**2)
# Off-site SOC Hamiltonian imaginary part
soc_offsite_imag = torch.zeros((offsite_hamiltonian.shape[0], 2*self.nao_max, 2*self.nao_max)).type_as(offsite_hamiltonian)
# Diagonal blocks
soc_offsite_imag[:, :self.nao_max, :self.nao_max] = self.symmetrize_offsite_hamiltonian(
(ksi_offsite*data.Loff[:, :, 2]), inverse_edge_indices, hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_offsite_imag[:, self.nao_max:, self.nao_max:] = -self.symmetrize_offsite_hamiltonian(
(ksi_offsite*data.Loff[:, :, 2]), inverse_edge_indices, hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
# Off-diagonal blocks
soc_offsite_imag[:, :self.nao_max, self.nao_max:] = self.symmetrize_offsite_hamiltonian(
(ksi_offsite*data.Loff[:, :, 0]), inverse_edge_indices, hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_offsite_imag[:, self.nao_max:, :self.nao_max] = -self.symmetrize_offsite_hamiltonian(
(ksi_offsite*data.Loff[:, :, 0]), inverse_edge_indices, hermitian=False
).reshape(-1, self.nao_max, self.nao_max)
soc_offsite_imag = soc_offsite_imag.reshape(-1, (2*self.nao_max)**2)
elif self.soc_basis == 'su2':
# SU(2) representation of spin-orbit coupling
node_spherical_harmonics = self.onsite_hamiltonian_network(node_attr)
onsite_hamiltonian = self.hamiltonian_decomposition.get_H(node_spherical_harmonics)
onsite_hamiltonian = self.reorder_matrix(onsite_hamiltonian)
onsite_hamiltonian = onsite_hamiltonian.reshape(-1, 2, 2, self.nao_max, self.nao_max)
onsite_hamiltonian = torch.swapaxes(onsite_hamiltonian, 2, 3) # shape (n_atoms, 2, nao_max, 2, nao_max)
onsite_hamiltonian = self.symmetrize_onsite_hamiltonian_soc(onsite_hamiltonian).reshape(-1, 2, self.nao_max, 2, self.nao_max) # Ensure Hermiticity
# Calculate off-site Hamiltonian
edge_spherical_harmonics = self.offsite_hamiltonian_network(edge_attr)
offsite_hamiltonian = self.hamiltonian_decomposition.get_H(edge_spherical_harmonics)
offsite_hamiltonian = self.reorder_matrix(offsite_hamiltonian)
offsite_hamiltonian = offsite_hamiltonian.reshape(-1, 2, 2, self.nao_max, self.nao_max)
offsite_hamiltonian = torch.swapaxes(offsite_hamiltonian, 2, 3) # shape (n_edges, 2, nao_max, 2, nao_max)
offsite_hamiltonian = self.symmetrize_offsite_hamiltonian_soc(offsite_hamiltonian, inverse_edge_indices).reshape(-1, 2, self.nao_max, 2, self.nao_max) # Ensure Hermiticity
# Apply orbital masking to each spin block
for i in range(2):
for j in range(2):
onsite_hamiltonian[:, i, :, j, :], offsite_hamiltonian[:, i, :, j, :] = self.apply_orbital_masks_to_hamiltonians(
onsite_hamiltonian[:, i, :, j, :], offsite_hamiltonian[:, i, :, j, :], data
)
# Reshape to flattened form
onsite_hamiltonian = onsite_hamiltonian.reshape(-1, (2*self.nao_max)**2)
offsite_hamiltonian = offsite_hamiltonian.reshape(-1, (2*self.nao_max)**2)
# Extract real and imaginary parts
soc_onsite_real = onsite_hamiltonian.real
soc_offsite_real = offsite_hamiltonian.real
soc_onsite_imag = onsite_hamiltonian.imag
soc_offsite_imag = offsite_hamiltonian.imag
else:
raise NotImplementedError("Unsupported SOC basis")
else: # No SOC, but spin-constrained
# Calculate standard Hamiltonian
node_spherical_harmonics = self.onsite_hamiltonian_network(node_attr)
node_spherical_components = torch.split(node_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
onsite_hamiltonian = self.merge_tensor_components(node_spherical_components)
onsite_hamiltonian = self.reorder_matrix(onsite_hamiltonian)
onsite_hamiltonian = self.symmetrize_onsite_hamiltonian(onsite_hamiltonian)
# Calculate off-site Hamiltonian
edge_spherical_harmonics = self.offsite_hamiltonian_network(edge_attr)
edge_spherical_components = torch.split(edge_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
offsite_hamiltonian = self.merge_tensor_components(edge_spherical_components)
offsite_hamiltonian = self.reorder_matrix(offsite_hamiltonian)
offsite_hamiltonian = self.symmetrize_offsite_hamiltonian(offsite_hamiltonian, inverse_edge_indices)
# Apply orbital masking
onsite_hamiltonian, offsite_hamiltonian = self.apply_orbital_masks_to_hamiltonians(onsite_hamiltonian, offsite_hamiltonian, data)
# For non-collinear spin without SOC, create block diagonal matrices
if not self.collinear_spin:
soc_onsite_real = torch.zeros_like(data.Hon).reshape(onsite_hamiltonian.shape[0], 2*self.nao_max, 2*self.nao_max)
soc_onsite_real[:, :self.nao_max, :self.nao_max] = onsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
soc_onsite_real[:, self.nao_max:, self.nao_max:] = onsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
soc_onsite_real = soc_onsite_real.reshape(onsite_hamiltonian.shape[0], (2*self.nao_max)**2)
soc_offsite_real = torch.zeros_like(data.Hoff).reshape(offsite_hamiltonian.shape[0], 2*self.nao_max, 2*self.nao_max)
soc_offsite_real[:, :self.nao_max, :self.nao_max] = offsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
soc_offsite_real[:, self.nao_max:, self.nao_max:] = offsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max)
soc_offsite_real = soc_offsite_real.reshape(offsite_hamiltonian.shape[0], (2*self.nao_max)**2)
# Zero imaginary parts
soc_onsite_imag = torch.zeros_like(data.iHon)
soc_offsite_imag = torch.zeros_like(data.iHoff)
# === MAGNETIC INTERACTION CALCULATION ===
if self.spin_constrained:
# Identify magnetic atoms
magnetic_atoms = (data.spin_length > self.min_magnetic_moment)
# Extract cell shift information
data.unique_cell_shift, data.cell_shift_indices, data.cell_index_map = self.extract_unique_cell_vectors(data)
cell_shift_indices = data.cell_shift_indices.tolist()
cell_index_map = data.cell_index_map
# Learn weight matrices if enabled
if self.use_learned_weight:
# Calculate on-site weight matrix
node_spherical_harmonics = self.onsite_weight_network(node_attr)
node_spherical_components = torch.split(node_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
weight_onsite = self.merge_tensor_components(node_spherical_components)
weight_onsite = self.reorder_matrix(weight_onsite)
weight_onsite = self.symmetrize_onsite_hamiltonian(weight_onsite)
# Calculate off-site weight matrix
edge_spherical_harmonics = self.offsite_weight_network(edge_attr)
edge_spherical_components = torch.split(edge_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
weight_offsite = self.merge_tensor_components(edge_spherical_components)
weight_offsite = self.reorder_matrix(weight_offsite)
weight_offsite = self.symmetrize_offsite_hamiltonian(weight_offsite, inverse_edge_indices)
# Apply orbital masking to weights
weight_onsite, weight_offsite = self.apply_orbital_masks_to_hamiltonians(weight_onsite, weight_offsite, data)
weight_onsite = weight_onsite.reshape(-1, self.nao_max, self.nao_max)
weight_offsite = weight_offsite.reshape(-1, self.nao_max, self.nao_max)
# Store weights for later use
data.weight_on = weight_onsite
data.weight_off = weight_offsite
# Calculate magnetic interaction terms
if self.soc_switch:
# Calculate Heisenberg J coupling for SOC
J_onsite = self.onsite_J_network(node_attr)
J_onsite = self.construct_j_coupling_matrix(J_onsite) # shape: (n_atoms, nao_max, nao_max, 3, 3)
J_offsite = self.offsite_J_network(edge_attr)
J_offsite = self.construct_j_coupling_matrix(J_offsite) # shape: (n_edges, nao_max, nao_max, 3, 3)
# Calculate quartic terms if enabled
if self.add_quartic:
K_onsite = self.onsite_K_network(node_attr)
K_onsite = self.construct_k_coupling_matrix(K_onsite) # shape: (n_atoms, nao_max, nao_max)
K_offsite = self.offsite_K_network(edge_attr)
K_offsite = self.construct_k_coupling_matrix(K_offsite) # shape: (n_edges, nao_max, nao_max)
# Define Pauli matrices for spin operators
sigma = torch.view_as_complex(torch.zeros((3, 2, 2, 2)).type_as(J_onsite))
sigma[0] = torch.Tensor([[0.0, 1.0], [1.0, 0.0]]).type_as(sigma) # sigma_x
sigma[1] = torch.complex(
real=torch.zeros((2, 2)),
imag=torch.Tensor([[0.0, -1.0], [1.0, 0.0]])
).type_as(sigma) # sigma_y
sigma[2] = torch.Tensor([[1.0, 0.0], [0.0, -1.0]]).type_as(sigma) # sigma_z
# Get spin vectors
spin_vec = data.spin_vec
# Initialize Heisenberg interaction terms
H_heisen_J_onsite = torch.zeros(len(J_onsite), 2, self.nao_max, 2, self.nao_max).type_as(sigma)
H_heisen_J_offsite = torch.zeros(len(source_indices), 2, self.nao_max, 2, self.nao_max).type_as(sigma)
# Optimize performance with edge lookup structures
edge_matcher_src, edge_matcher_tar = self.build_edge_lookup_structures(data, inverse_edge_indices)
# Calculate on-site Heisenberg terms for magnetic atoms
# S*J*S terms using opt_einsum for efficient contraction
H_heisen_J_onsite[magnetic_atoms] += oe.contract(
'mijkl,mij,kop,ml->moipj',
J_onsite[magnetic_atoms].type_as(sigma),
weight_onsite[magnetic_atoms].type_as(sigma),
sigma,
spin_vec[magnetic_atoms].type_as(sigma)
)
H_heisen_J_onsite[magnetic_atoms] += oe.contract(
'mijkl,mij,lop,mk->moipj',
J_onsite[magnetic_atoms].type_as(sigma),
weight_onsite[magnetic_atoms].type_as(sigma),
sigma,
spin_vec[magnetic_atoms].type_as(sigma)
)
# Calculate off-site interactions via neighboring atoms
zero_shift_idx = cell_index_map[(0, 0, 0)]
for atom_idx in range(len(J_onsite)):
# Process source atom contributions
if magnetic_atoms[atom_idx]:
# Get all edges where this atom appears, including zero-shift cells
zero_shift_edges = edge_matcher_tar[atom_idx][zero_shift_idx]
source_edges = torch.cat([edge_matcher_src[atom_idx], zero_shift_edges])
offsite_weights = weight_offsite[source_edges]
# Calculate Heisenberg interactions for these edges
H_heisen_J_offsite[source_edges] += oe.contract(
'ijkl,mij,kop,l->moipj',
J_onsite[atom_idx].type_as(sigma),
offsite_weights.type_as(sigma),
sigma,
spin_vec[atom_idx].type_as(sigma)
)
H_heisen_J_offsite[source_edges] += oe.contract(
'ijkl,mij,lop,k->moipj',
J_onsite[atom_idx].type_as(sigma),
offsite_weights.type_as(sigma),
sigma,
spin_vec[atom_idx].type_as(sigma)
)
# Process edge-specific interactions
for edge_idx in range(len(source_indices)):
source_atom = source_indices[edge_idx].item()
target_atom = target_indices[edge_idx].item()
# Target atom contribution to source
if magnetic_atoms[target_atom]:
onsite_weight = weight_onsite[source_atom]
source_offsite_weights = weight_offsite[edge_matcher_src[source_atom]]
# Update on-site term at source atom
H_heisen_J_onsite[source_atom] += oe.contract(
'ijkl,ij,kop,l->oipj',
J_offsite[edge_idx].type_as(sigma),
onsite_weight.type_as(sigma),
sigma,
spin_vec[target_atom].type_as(sigma)
)
# Update off-site terms from source atom
H_heisen_J_offsite[edge_matcher_src[source_atom]] += oe.contract(
'ijkl,mij,kop,l->moipj',
J_offsite[edge_idx].type_as(sigma),
source_offsite_weights.type_as(sigma),
sigma,
spin_vec[target_atom].type_as(sigma)
)
# Source atom contribution to target
if magnetic_atoms[source_atom]:
target_offsite_weights = weight_offsite[edge_matcher_tar[target_atom][cell_shift_indices[edge_idx]]]
# Update off-site terms to target atom
H_heisen_J_offsite[edge_matcher_tar[target_atom][cell_shift_indices[edge_idx]]] += oe.contract(
'ijkl,mij,lop,k->moipj',
J_offsite[edge_idx].type_as(sigma),
target_offsite_weights.type_as(sigma),
sigma,
spin_vec[source_atom].type_as(sigma)
)
# For zero-cell-shift, also update on-site term at target
if cell_shift_indices[edge_idx] == data.cell_index_map[(0, 0, 0)]:
target_onsite_weight = weight_onsite[target_atom]
H_heisen_J_onsite[target_atom] += oe.contract(
'ijkl,ij,lop,k->oipj',
J_offsite[edge_idx].type_as(sigma),
target_onsite_weight.type_as(sigma),
sigma,
spin_vec[source_atom].type_as(sigma)
)
else:
# Non-SOC magnetic interactions
J_onsite = self.onsite_J_network(node_attr)
J_onsite = self.construct_j_coupling_matrix(J_onsite) # shape: (n_atoms, nao_max, nao_max)
J_offsite = self.offsite_J_network(edge_attr)
J_offsite = self.construct_j_coupling_matrix(J_offsite) # shape: (n_edges, nao_max, nao_max)
# Calculate quartic terms if enabled
if self.add_quartic:
K_onsite = self.onsite_K_network(node_attr)
K_onsite = self.construct_k_coupling_matrix(K_onsite)
K_offsite = self.offsite_K_network(edge_attr)
K_offsite = self.construct_k_coupling_matrix(K_offsite)
# Handle collinear vs non-collinear spin
if self.collinear_spin:
# Collinear spin uses only z-component
sigma_z = torch.Tensor([[1.0, 0.0], [0.0, -1.0]]).type_as(J_onsite)
spin_vec = data.spin_vec
# Initialize Heisenberg interaction matrices
H_heisen_J_onsite = torch.zeros(len(J_onsite), 2, self.nao_max, 2, self.nao_max).type_as(J_onsite)
H_heisen_J_offsite = torch.zeros(len(source_indices), 2, self.nao_max, 2, self.nao_max).type_as(J_offsite)
# Build edge lookup for efficient calculation
edge_matcher_src, edge_matcher_tar = self.build_edge_lookup_structures(data, inverse_edge_indices)
# Calculate on-site contributions for magnetic atoms
H_heisen_J_onsite[magnetic_atoms] += oe.contract(
'mij,mij,op,m->moipj',
J_onsite[magnetic_atoms],
weight_onsite[magnetic_atoms],
sigma_z,
spin_vec[magnetic_atoms, 2] # z-component only
)
# Calculate off-site contributions
zero_shift_idx = cell_index_map[(0, 0, 0)]
for atom_idx in range(len(J_onsite)):
if magnetic_atoms[atom_idx]:
zero_shift_edges = edge_matcher_tar[atom_idx][zero_shift_idx]
source_edges = torch.cat([edge_matcher_src[atom_idx], zero_shift_edges])
offsite_weights = weight_offsite[source_edges]
H_heisen_J_offsite[source_edges] += oe.contract(
'ij,mij,op->moipj',
J_onsite[atom_idx],
offsite_weights,
sigma_z
) * spin_vec[atom_idx, 2]
# Process edge-specific interactions
for edge_idx in range(len(source_indices)):
source_atom = source_indices[edge_idx].item()
target_atom = target_indices[edge_idx].item()
# Target atom contribution to source
if magnetic_atoms[target_atom]:
onsite_weight = weight_onsite[source_atom]
source_offsite_weights = weight_offsite[edge_matcher_src[source_atom]]
H_heisen_J_onsite[source_atom] += oe.contract(
'ij,ij,op->oipj',
J_offsite[edge_idx],
onsite_weight,
sigma_z
) * spin_vec[target_atom, 2]
H_heisen_J_offsite[edge_matcher_src[source_atom]] += oe.contract(
'ij,mij,op->moipj',
J_offsite[edge_idx],
source_offsite_weights,
sigma_z
) * spin_vec[target_atom, 2]
# Source atom contribution to target
if magnetic_atoms[source_atom]:
target_offsite_weights = weight_offsite[edge_matcher_tar[target_atom][cell_shift_indices[edge_idx]]]
H_heisen_J_offsite[edge_matcher_tar[target_atom][cell_shift_indices[edge_idx]]] += oe.contract(
'ij,mij,op->moipj',
J_offsite[edge_idx],
target_offsite_weights,
sigma_z
) * spin_vec[source_atom, 2]
if cell_shift_indices[edge_idx] == data.cell_index_map[(0, 0, 0)]:
target_onsite_weight = weight_onsite[target_atom]
H_heisen_J_onsite[target_atom] += oe.contract(
'ij,ij,op->oipj',
J_offsite[edge_idx],
target_onsite_weight,
sigma_z
) * spin_vec[source_atom, 2]
else:
# Non-collinear spin handling (similar to SOC case but simpler)
# Define Pauli matrices
sigma = torch.view_as_complex(torch.zeros((3, 2, 2, 2)).type_as(J_onsite))
sigma[0] = torch.Tensor([[0.0, 1.0], [1.0, 0.0]]).type_as(sigma)
sigma[1] = torch.complex(
real=torch.zeros((2, 2)),
imag=torch.Tensor([[0.0, -1.0], [1.0, 0.0]])
).type_as(sigma)
sigma[2] = torch.Tensor([[1.0, 0.0], [0.0, -1.0]]).type_as(sigma)
spin_vec = data.spin_vec
# Initialize Heisenberg matrices
H_heisen_J_onsite = torch.zeros(len(J_onsite), 2, self.nao_max, 2, self.nao_max).type_as(sigma)
H_heisen_J_offsite = torch.zeros(len(source_indices), 2, self.nao_max, 2, self.nao_max).type_as(sigma)
# Build edge lookup
edge_matcher_src, edge_matcher_tar = self.build_edge_lookup_structures(data, inverse_edge_indices)
# Calculate on-site contributions for magnetic atoms
H_heisen_J_onsite[magnetic_atoms] += oe.contract(
'mij,mij,kop,mk->moipj',
J_onsite[magnetic_atoms].type_as(sigma),
weight_onsite[magnetic_atoms].type_as(sigma),
sigma,
spin_vec[magnetic_atoms].type_as(sigma)
)
# Calculate off-site contributions
zero_shift_idx = cell_index_map[(0, 0, 0)]
for atom_idx in range(len(J_onsite)):
if magnetic_atoms[atom_idx]:
zero_shift_edges = edge_matcher_tar[atom_idx][zero_shift_idx]
source_edges = torch.cat([edge_matcher_src[atom_idx], zero_shift_edges])
offsite_weights = weight_offsite[source_edges]
H_heisen_J_offsite[source_edges] += oe.contract(
'ij,mij,kop,k->moipj',
J_onsite[atom_idx].type_as(sigma),
offsite_weights.type_as(sigma),
sigma,
spin_vec[atom_idx].type_as(sigma)
)
# Process edge-specific interactions
for edge_idx in range(len(source_indices)):
source_atom = source_indices[edge_idx].item()
target_atom = target_indices[edge_idx].item()
# Target atom contribution to source
if magnetic_atoms[target_atom]:
onsite_weight = weight_onsite[source_atom]
source_offsite_weights = weight_offsite[edge_matcher_src[source_atom]]
H_heisen_J_onsite[source_atom] += oe.contract(
'ij,ij,kop,k->oipj',
J_offsite[edge_idx].type_as(sigma),
onsite_weight.type_as(sigma),
sigma,
spin_vec[target_atom].type_as(sigma)
)
H_heisen_J_offsite[edge_matcher_src[source_atom]] += oe.contract(
'ij,mij,kop,k->moipj',
J_offsite[edge_idx].type_as(sigma),
source_offsite_weights.type_as(sigma),
sigma,
spin_vec[target_atom].type_as(sigma)
)
# Source atom contribution to target
if magnetic_atoms[source_atom]:
target_offsite_weights = weight_offsite[edge_matcher_tar[target_atom][cell_shift_indices[edge_idx]]]
H_heisen_J_offsite[edge_matcher_tar[target_atom][cell_shift_indices[edge_idx]]] += oe.contract(
'ij,mij,kop,k->moipj',
J_offsite[edge_idx].type_as(sigma),
target_offsite_weights.type_as(sigma),
sigma,
spin_vec[source_atom].type_as(sigma)
)
if cell_shift_indices[edge_idx] == data.cell_index_map[(0, 0, 0)]:
target_onsite_weight = weight_onsite[target_atom]
H_heisen_J_onsite[target_atom] += oe.contract(
'ij,ij,kop,k->oipj',
J_offsite[edge_idx].type_as(sigma),
target_onsite_weight.type_as(sigma),
sigma,
spin_vec[source_atom].type_as(sigma)
)
# Combine Heisenberg terms with SOC Hamiltonian
if not self.collinear_spin:
# Add real and imaginary parts separately
soc_onsite_real = soc_onsite_real + H_heisen_J_onsite.reshape(-1, (2*self.nao_max)**2).real
soc_offsite_real = soc_offsite_real + H_heisen_J_offsite.reshape(-1, (2*self.nao_max)**2).real
soc_onsite_imag = soc_onsite_imag + H_heisen_J_onsite.reshape(-1, (2*self.nao_max)**2).imag
soc_offsite_imag = soc_offsite_imag + H_heisen_J_offsite.reshape(-1, (2*self.nao_max)**2).imag
# Apply symmetrization if enabled
if self.symmetrize:
soc_onsite_real = self.symmetrize_onsite_hamiltonian_soc(soc_onsite_real, hermitian=True)
soc_offsite_real = self.symmetrize_offsite_hamiltonian_soc(soc_offsite_real, inverse_edge_indices, hermitian=True)
soc_onsite_imag = self.symmetrize_onsite_hamiltonian_soc(soc_onsite_imag, hermitian=False)
soc_offsite_imag = self.symmetrize_offsite_hamiltonian_soc(soc_offsite_imag, inverse_edge_indices, hermitian=False)
else:
# For collinear spin, create separate up and down Hamiltonians
collinear_onsite = torch.stack([
onsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max) + H_heisen_J_onsite[:, 0, :, 0, :],
onsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max) + H_heisen_J_onsite[:, 1, :, 1, :]
], dim=1).reshape(-1, 2, (self.nao_max)**2)
collinear_offsite = torch.stack([
offsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max) + H_heisen_J_offsite[:, 0, :, 0, :],
offsite_hamiltonian.reshape(-1, self.nao_max, self.nao_max) + H_heisen_J_offsite[:, 1, :, 1, :]
], dim=1).reshape(-1, 2, (self.nao_max)**2)
# Add reference H0 terms if enabled
if self.add_H0:
if not self.collinear_spin:
soc_onsite_real = soc_onsite_real + data.Hon0
soc_offsite_real = soc_offsite_real + data.Hoff0
soc_onsite_imag = soc_onsite_imag + data.iHon0
soc_offsite_imag = soc_offsite_imag + data.iHoff0
else:
collinear_onsite = collinear_onsite + data.Hon0
collinear_offsite = collinear_offsite + data.Hoff0
# === COMBINE HAMILTONIANS AND CALCULATE BAND STRUCTURES ===
if not self.collinear_spin:
# Combine real and imaginary parts of SOC Hamiltonian
soc_hamiltonian_real = self.concatenate_hamiltonians_by_crystal(data, soc_onsite_real, soc_offsite_real)
soc_hamiltonian_imag = self.concatenate_hamiltonians_by_crystal(data, soc_onsite_imag, soc_offsite_imag)
# Store reference Hamiltonians
data.hamiltonian_real = self.concatenate_hamiltonians_by_crystal(data, data.Hon, data.Hoff)
data.hamiltonian_imag = self.concatenate_hamiltonians_by_crystal(data, data.iHon, data.iHoff)
# Combine real and imaginary parts
soc_hamiltonian = torch.cat((soc_hamiltonian_real, soc_hamiltonian_imag), dim=0)
data.hamiltonian = torch.cat((data.hamiltonian_real, data.hamiltonian_imag), dim=0)
# Calculate band structure if requested
if self.calculate_band_energy:
# Generate k-points for each crystal
k_vectors = []
for idx in range(data.batch[-1] + 1):
lattice_vectors = data.cell
# Generate k-point path based on configuration
if self.k_path is not None:
kpoints = kpoints_generator(dim_k=3, lat=lattice_vectors[idx].detach().cpu().numpy())
k_vec, k_dist, k_node, lat_per_inv = kpoints.k_path(self.k_path, self.num_k)
else:
# Random k-points if no path specified
lat_per_inv = np.linalg.inv(lattice_vectors[idx].detach().cpu().numpy()).T
k_vec = 2.0 * np.random.rand(self.num_k, 3) - 1.0 # (-1, 1)
# Transform k-points to reciprocal space
k_vec = k_vec.dot(lat_per_inv[np.newaxis, :, :]) # shape (nk, 1, 3)
k_vec = k_vec.reshape(-1, 3) # shape (nk, 3)
k_vec = torch.Tensor(k_vec).type_as(onsite_hamiltonian)
k_vectors.append(k_vec)
# Store k-vectors for band structure calculation
data.k_vecs = torch.stack(k_vectors, dim=0)
# Calculate band energies and wavefunctions with SOC
band_energy, wavefunction = self.calculate_band_energies_with_spin_orbit_coupling(
soc_onsite_real, soc_onsite_imag, soc_offsite_real, soc_offsite_imag, data
)
# Calculate reference band structure
with torch.no_grad():
data.band_energy, data.wavefunction = self.calculate_band_energies_with_spin_orbit_coupling(
data.Hon, data.iHon, data.Hoff, data.iHoff, data
)
else:
band_energy = None
wavefunction = None
else:
# For collinear spin or non-SOC case
collinear_hamiltonian = self.concatenate_hamiltonians_by_crystal(data, collinear_onsite, collinear_offsite)
data.hamiltonian = self.concatenate_hamiltonians_by_crystal(data, data.Hon, data.Hoff)
# Calculate band structure if requested
if self.calculate_band_energy:
k_vectors = []
for idx in range(data.batch[-1] + 1):
lattice_vectors = data.cell
# Generate k-point path based on configuration
if isinstance(self.k_path, list):
kpoints = kpoints_generator(dim_k=3, lat=lattice_vectors[idx].detach().cpu().numpy())
k_vec, k_dist, k_node, lat_per_inv = kpoints.k_path(self.k_path, self.num_k)
elif isinstance(self.k_path, str) and self.k_path.lower() == 'auto':
# Automatic k-path generation based on crystal symmetry
lattice = lattice_vectors[idx].detach().cpu().numpy() * au2ang
positions = torch.split(data.pos, data.node_counts.tolist(), dim=0)[idx].detach().cpu().numpy() * au2ang
species = torch.split(data.z, data.node_counts.tolist(), dim=0)[idx]
# Create pymatgen Structure
structure = Structure(
lattice=lattice,
species=[Element.from_Z(k.item()).symbol for k in species],
coords=positions,
coords_are_cartesian=True
)
# Generate k-path using symmetry
kpath_seek = KPathSeek(structure=structure)
k_labels = []
for label_group in kpath_seek.kpath['path']:
k_labels += label_group
# Remove adjacent duplicates
unique_labels = [k_labels[0]]
[unique_labels.append(x) for x in k_labels[1:] if x != unique_labels[-1]]
k_path = [kpath_seek.kpath['kpoints'][k] for k in unique_labels]
try:
kpoints = kpoints_generator(dim_k=3, lat=lattice_vectors[idx].detach().cpu().numpy())
k_vec, k_dist, k_node, lat_per_inv = kpoints.k_path(k_path, self.num_k)
except:
# Fallback to random k-points if path generation fails
lat_per_inv = np.linalg.inv(lattice_vectors[idx].detach().cpu().numpy()).T
k_vec = 2.0 * np.random.rand(self.num_k, 3) - 1.0 # (-1, 1)
else:
# Random k-points if no path specified
lat_per_inv = np.linalg.inv(lattice_vectors[idx].detach().cpu().numpy()).T
k_vec = 2.0 * np.random.rand(self.num_k, 3) - 1.0 # (-1, 1)
# Transform k-points to reciprocal space
k_vec = k_vec.dot(lat_per_inv[np.newaxis, :, :])
k_vec = k_vec.reshape(-1, 3)
k_vec = torch.Tensor(k_vec).type_as(onsite_hamiltonian)
k_vectors.append(k_vec)
# Store k-vectors for band structure calculation
data.k_vecs = torch.stack(k_vectors, dim=0)
if self.export_reciprocal_values:
# Calculate band structure for both spin channels with reciprocal space matrices
band_energy_up, wavefunction_up, HK_up, SK_up, dSK_up, gap_up = self.calculate_band_energies(
collinear_onsite[:, 0, :], collinear_offsite[:, 0, :], data, True
)
band_energy_down, wavefunction_down, HK_down, SK_down, dSK_down, gap_down = self.calculate_band_energies(
collinear_onsite[:, 1, :], collinear_offsite[:, 1, :], data, True
)
H_sym = None
band_energy = torch.cat([band_energy_up, band_energy_down])
wavefunction = torch.cat([wavefunction_up, wavefunction_down])
HK = torch.cat([HK_up, HK_down])
gap = torch.cat([gap_up, gap_down])
else:
# Calculate band structure for both spin channels
band_energy_up, wavefunction_up, gap_up, H_sym = self.calculate_band_energies(
collinear_onsite[:, 0, :], collinear_offsite[:, 0, :], data
)
band_energy_down, wavefunction_down, gap_down, H_sym = self.calculate_band_energies(
collinear_onsite[:, 1, :], collinear_offsite[:, 1, :], data
)
band_energy = torch.cat([band_energy_up, band_energy_down])
wavefunction = torch.cat([wavefunction_up, wavefunction_down])
gap = torch.cat([gap_up, gap_down])
# Calculate reference band structure
with torch.no_grad():
data.band_energy_up, data.wavefunction, data.band_gap_up, data.H_sym = self.calculate_band_energies(
data.Hon[:, 0, :], data.Hoff[:, 0, :], data
)
data.band_energy_down, data.wavefunction, data.band_gap_down, data.H_sym = self.calculate_band_energies(
data.Hon[:, 1, :], data.Hoff[:, 1, :], data
)
data.band_energy = torch.cat([data.band_energy_up, data.band_energy_down])
data.band_gap = torch.cat([data.band_gap_up, data.band_gap_down])
else:
band_energy = None
wavefunction = None
gap = None
H_sym = None
# === NON-MAGNETIC, NON-SOC CASE ===
else:
# Calculate standard Hamiltonian
node_spherical_harmonics = self.onsite_hamiltonian_network(node_attr)
node_spherical_components = torch.split(node_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
onsite_hamiltonian = self.merge_tensor_components(node_spherical_components)
onsite_hamiltonian = self.reorder_matrix(onsite_hamiltonian)
onsite_hamiltonian = self.symmetrize_onsite_hamiltonian(onsite_hamiltonian)
# Add reference H0 if available
if self.add_H0:
onsite_hamiltonian = onsite_hamiltonian + data.Hon0
# Calculate off-site Hamiltonian
edge_spherical_harmonics = self.offsite_hamiltonian_network(edge_attr)
edge_spherical_components = torch.split(edge_spherical_harmonics, self.hamiltonian_irreps_dimensions.tolist(), dim=-1)
offsite_hamiltonian = self.merge_tensor_components(edge_spherical_components)
offsite_hamiltonian = self.reorder_matrix(offsite_hamiltonian)
offsite_hamiltonian = self.symmetrize_offsite_hamiltonian(offsite_hamiltonian, inverse_edge_indices)
# Add reference H0 if available
if self.add_H0:
offsite_hamiltonian = offsite_hamiltonian + data.Hoff0
# Apply orbital masking for specific Hamiltonian types
if self.ham_type in ['openmx', 'pasp', 'siesta', 'abacus']:
onsite_hamiltonian, offsite_hamiltonian = self.apply_orbital_masks_to_hamiltonians(onsite_hamiltonian, offsite_hamiltonian, data)
# Calculate band structure if requested
if self.calculate_band_energy:
# Generate k-points for each crystal
k_vectors = []
for idx in range(data.batch[-1] + 1):
lattice_vectors = data.cell
# Generate k-point path based on configuration
if isinstance(self.k_path, list):
kpoints = kpoints_generator(dim_k=3, lat=lattice_vectors[idx].detach().cpu().numpy())
k_vec, k_dist, k_node, lat_per_inv = kpoints.k_path(self.k_path, self.num_k)
elif isinstance(self.k_path, str) and self.k_path.lower() == 'auto':
# Automatic k-path generation
lattice = lattice_vectors[idx].detach().cpu().numpy() * au2ang
positions = torch.split(data.pos, data.node_counts.tolist(), dim=0)[idx].detach().cpu().numpy() * au2ang
species = torch.split(data.z, data.node_counts.tolist(), dim=0)[idx]
structure = Structure(
lattice=lattice,
species=[Element.from_Z(k.item()).symbol for k in species],
coords=positions,
coords_are_cartesian=True
)
kpath_seek = KPathSeek(structure=structure)
k_labels = []
for label_group in kpath_seek.kpath['path']:
k_labels += label_group
unique_labels = [k_labels[0]]
[unique_labels.append(x) for x in k_labels[1:] if x != unique_labels[-1]]
k_path = [kpath_seek.kpath['kpoints'][k] for k in unique_labels]
try:
kpoints = kpoints_generator(dim_k=3, lat=lattice_vectors[idx].detach().cpu().numpy())
k_vec, k_dist, k_node, lat_per_inv = kpoints.k_path(k_path, self.num_k)
except:
# Fallback to random k-points
lat_per_inv = np.linalg.inv(lattice_vectors[idx].detach().cpu().numpy()).T
k_vec = 2.0 * np.random.rand(self.num_k, 3) - 1.0
else:
# Random k-points
lat_per_inv = np.linalg.inv(lattice_vectors[idx].detach().cpu().numpy()).T
k_vec = 2.0 * np.random.rand(self.num_k, 3) - 1.0
# Transform k-points to reciprocal space
k_vec = k_vec.dot(lat_per_inv[np.newaxis, :, :])
k_vec = k_vec.reshape(-1, 3)
k_vec = torch.Tensor(k_vec).type_as(onsite_hamiltonian)
k_vectors.append(k_vec)
# Store k-vectors for band structure calculation
data.k_vecs = torch.stack(k_vectors, dim=0)
if self.export_reciprocal_values:
if self.ham_only:
# Calculate band structure with reciprocal space matrices
band_energy, wavefunction, HK, SK, dSK, gap = self.calculate_band_energies(
onsite_hamiltonian, offsite_hamiltonian, data, True
)
H_sym = None
else:
# Calculate band structure with overlap matrices
band_energy, wavefunction, HK, SK, dSK, gap = self.calculate_band_energies_with_overlap(
onsite_hamiltonian, offsite_hamiltonian, onsite_overlap, offsite_overlap, data, True
)
H_sym = None
else:
# Standard band structure calculation
band_energy, wavefunction, gap, H_sym = self.calculate_band_energies(
onsite_hamiltonian, offsite_hamiltonian, data
)
# Calculate reference band structure
with torch.no_grad():
data.band_energy, data.wavefunction, data.band_gap, data.H_sym = self.calculate_band_energies(
data.Hon, data.Hoff, data
)
else:
band_energy = None
wavefunction = None
gap = None
H_sym = None
# === PREPARE RESULTS ===
# Process results based on Hamiltonian type
if self.ham_type in ['openmx', 'pasp', 'siesta', 'abacus']:
if self.soc_switch or self.spin_constrained:
if not self.collinear_spin:
# Handle zero point energy shift if enabled
if self.zero_point_shift:
# Calculate energy shift to match reference
overlap_matrix = data.overlap.reshape(-1, self.nao_max, self.nao_max)
soc_hamiltonian_real_reshaped = soc_hamiltonian_real.reshape(-1, 2, self.nao_max, 2, self.nao_max)
real_up_up_block = soc_hamiltonian_real_reshaped[:, 0, :, 0, :]
real_down_down_block = soc_hamiltonian_real_reshaped[:, 1, :, 1, :]
data_hamiltonian_real_reshaped = data.hamiltonian_real.reshape(-1, 2, self.nao_max, 2, self.nao_max)
data_real_up_up_block = data_hamiltonian_real_reshaped[:, 0, :, 0, :]
data_real_down_down_block = data_hamiltonian_real_reshaped[:, 1, :, 1, :]
sum_overlap = torch.sum(overlap_matrix[overlap_matrix > 1e-6])
diagonal_block_difference = (real_up_up_block + real_down_down_block) - (data_real_up_up_block + data_real_down_down_block)
energy_shift_real = torch.sum(extract_elements_above_threshold(
overlap_matrix, diagonal_block_difference , 1e-6
)) / (2.0*sum_overlap)
# Apply energy shift
soc_hamiltonian_real_reshaped[:, 0, :, 0, :] = real_up_up_block - energy_shift_real * overlap_matrix
soc_hamiltonian_real_reshaped[:, 1, :, 1, :] = real_down_down_block - energy_shift_real * overlap_matrix
soc_hamiltonian_real = soc_hamiltonian_real_reshaped.reshape(-1, (2*self.nao_max)**2)
soc_hamiltonian = torch.cat((soc_hamiltonian_real, soc_hamiltonian_imag), dim=0)
# Adjust band energies if calculated
if band_energy is not None:
band_energy = band_energy - torch.mean(band_energy - data.band_energy)
# Prepare result dictionary
result = {
'hamiltonian': soc_hamiltonian,
'hamiltonian_real': soc_hamiltonian_real,
'hamiltonian_imag': soc_hamiltonian_imag,
'band_energy': band_energy,
'wavefunction': wavefunction
}
# Add mask tensor if requested
if self.get_nonzero_mask_tensor:
mask_real_imag, mask_all = self.build_spin_orbit_interaction_masks(data)
result['mask_real_imag'] = mask_real_imag
else: # Collinear spin
# Handle zero point energy shift if enabled
if self.zero_point_shift:
# Calculate energy shift to match reference
overlap_matrix = data.overlap
collinear_overlap = torch.stack([overlap_matrix, overlap_matrix], dim=1)
sum_collinear_overlap = 2 * torch.sum(overlap_matrix[overlap_matrix > 1e-6])
energy_shift = torch.sum(extract_elements_above_threshold(
collinear_overlap, collinear_hamiltonian - data.hamiltonian, 1e-6
)) / sum_collinear_overlap
# Apply energy shift
collinear_hamiltonian = collinear_hamiltonian - energy_shift * collinear_overlap
# Adjust band energies if calculated
if band_energy is not None:
band_energy = band_energy - torch.mean(band_energy - data.band_energy)
# Prepare result dictionary
result = {
'hamiltonian': collinear_hamiltonian,
'band_energy': band_energy,
'wavefunction': wavefunction
}
# Add mask tensor if requested
if self.get_nonzero_mask_tensor:
mask_all = self.build_column_wise_interaction_masks(data)
result['mask'] = mask_all
else:
# Standard non-magnetic case
combined_hamiltonian = self.concatenate_hamiltonians_by_crystal(data, onsite_hamiltonian, offsite_hamiltonian)
# Handle zero point energy shift if enabled
if self.zero_point_shift:
# Calculate energy shift to match reference
overlap_matrix = data.overlap
sum_overlap = torch.sum(overlap_matrix[overlap_matrix > 1e-6])
energy_shift = torch.sum(extract_elements_above_threshold(
overlap_matrix, combined_hamiltonian - data.hamiltonian, 1e-6
)) / sum_overlap
# Apply energy shift
combined_hamiltonian = combined_hamiltonian - energy_shift * data.overlap
# Adjust band energies if calculated
if band_energy is not None:
band_energy = band_energy - torch.mean(band_energy - data.band_energy)
# Prepare result dictionary
result = {
'hamiltonian': combined_hamiltonian,
'band_energy': band_energy,
'wavefunction': wavefunction,
'band_gap': gap,
'H_sym': H_sym
}
# Add reciprocal space matrices if requested
if self.export_reciprocal_values:
result.update({'HK': HK, 'SK': SK, 'dSK': dSK})
# Add mask tensor if requested
if self.get_nonzero_mask_tensor:
mask_all = self.build_interaction_masks(data)
result['mask'] = mask_all
else:
raise NotImplementedError("Unsupported Hamiltonian type")
# Add overlap matrix if calculating both Hamiltonian and overlap
if not self.ham_only:
if self.ham_type in ['openmx', 'pasp', 'siesta', 'abacus']:
overlap = self.concatenate_hamiltonians_by_crystal(data, onsite_overlap, offsite_overlap)
else:
raise NotImplementedError("Unsupported Hamiltonian type for overlap calculation")
result.update({'overlap': overlap})
# Calculate and add sparsity ratio if requested
if self.calculate_sparsity:
result['sparsity_ratio'] = self.calculate_sparsity_ratio(data)
return result