# Source code for hamgnn.models.base_model

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""Base neural-network module and graph construction utilities for HamGNN.

Provides :class:`BaseModel` (abstract ``forward``; batch graph generation from
structures), atomic-radius tables aligned with DFT codes (OpenMX, ABACUS),
ASE-based neighbor lists with periodic boundary conditions, and helpers to align
recomputed edges with pre-existing edge data via column matching.
"""

from __future__ import annotations

import torch
import torch.nn as nn
from torch_scatter import scatter
from easydict import EasyDict

import warnings
from ase import geometry, neighborlist
import numpy as np
from pymatgen.core.periodic_table import Element
from typing import List, Union

# Per-element radii, keyed first by DFT code name and then by element symbol.
# get_radii_from_atomic_numbers() multiplies these values by ``radius_scale``
# before they are used as neighbor-search cutoffs.
ATOMIC_RADII = {
    # NOTE(review): units are not stated for the OpenMX table; presumably the
    # same atomic units as the ABACUS table below -- confirm.
    'openmx': {
        'H': 6.0, 'He': 8.0, 'Li': 8.0, 'Be': 7.0, 'B': 7.0, 'C': 6.0,
        'N': 6.0, 'O': 6.0, 'F': 6.0, 'Ne': 9.0, 'Na': 9.0, 'Mg': 9.0,
        'Al': 7.0, 'Si': 7.0, 'P': 7.0, 'S': 7.0, 'Cl': 7.0, 'Ar': 9.0,
        'K': 10.0, 'Ca': 9.0, 'Sc': 9.0, 'Ti': 7.0, 'V': 6.0, 'Cr': 6.0,
        'Mn': 6.0, 'Fe': 5.5, 'Co': 6.0, 'Ni': 6.0, 'Cu': 6.0, 'Zn': 6.0,
        'Ga': 7.0, 'Ge': 7.0, 'As': 7.0, 'Se': 7.0, 'Br': 7.0, 'Kr': 10.0,
        'Rb': 11.0, 'Sr': 10.0, 'Y': 10.0, 'Zr': 7.0, 'Nb': 7.0, 'Mo': 7.0,
        'Tc': 7.0, 'Ru': 7.0, 'Rh': 7.0, 'Pd': 7.0, 'Ag': 7.0, 'Cd': 7.0,
        'In': 7.0, 'Sn': 7.0, 'Sb': 7.0, 'Te': 7.0, 'I': 7.0, 'Xe': 11.0,
        'Cs': 12.0, 'Ba': 10.0, 'La': 8.0, 'Ce': 8.0, 'Pr': 8.0, 'Nd': 8.0,
        'Pm': 8.0, 'Sm': 8.0, 'Dy': 8.0, 'Ho': 8.0, 'Lu': 8.0, 'Hf': 9.0,
        'Ta': 7.0, 'W': 7.0, 'Re': 7.0, 'Os': 7.0, 'Ir': 7.0, 'Pt': 7.0,
        'Au': 7.0, 'Hg': 8.0, 'Tl': 8.0, 'Pb': 8.0, 'Bi': 8.0
    },
    # Intentionally empty: with radius_type='siesta' every element falls back
    # to DEFAULT_RADIUS in get_radii_from_atomic_numbers().
    'siesta':{},
    'abacus': { # unit: au
    'Ag':7,  'Cu':8,  'Mo':7,  'Sc':8,
    'Al':7,  'Fe':8,  'Na':8,  'Se':8,
    'Ar':7,  'F' :7,  'Nb':8,  'S' :7,
    'As':7,  'Ga':8,  'Ne':6,  'Si':7,
    'Au':7,  'Ge':8,  'N' :7,  'Sn':7,
    'Ba':10, 'He':6,  'Ni':8,  'Sr':9,
    'Be':7,  'Hf':7,  'O' :7,  'Ta':8,
    'B' :8,  'H' :6,  'Os':7,  'Tc':7,
    'Bi':7,  'Hg':9,  'Pb':7,  'Te':7,
    'Br':7,  'I' :7,  'Pd':7,  'Ti':8,
    'Ca':9,  'In':7,  'P' :7,  'Tl':7,
    'Cd':7,  'Ir':7,  'Pt':7,  'V' :8,
    'C' :7,  'K' :9,  'Rb':10, 'W' :8,
    'Cl':7,  'Kr':7,  'Re':7,  'Xe':8,
    'Co':8,  'Li':7,  'Rh':7,  'Y' :8,
    'Cr':8,  'Mg':8,  'Ru':7,  'Zn':8,
    'Cs':10, 'Mn':8,  'Sb':7,  'Zr':8
}
}

# Unscaled radius used for any element missing from the selected table.
DEFAULT_RADIUS = 10.0

def get_radii_from_atomic_numbers(
    atomic_numbers: Union[torch.Tensor, List[int]],
    radius_scale: float = 1.5,
    radius_type: str = 'openmx',
) -> List[float]:
    """
    Retrieves the scaled atomic radii for a given list or tensor of atomic numbers.

    Parameters:
    - atomic_numbers (Union[torch.Tensor, List[int]]): A list or tensor containing atomic numbers.
    - radius_scale (float): A scaling factor to multiply the atomic radii. Default is 1.5.
    - radius_type (str): Key of the radius table in ``ATOMIC_RADII`` to use
      (the DFT code the radii belong to). Default is openmx.

    Returns:
    - List[float]: Scaled radii, one per input atomic number and in the same
      order. Elements missing from the selected table use ``DEFAULT_RADIUS``
      (also multiplied by ``radius_scale``).

    Raises:
    - KeyError: If ``radius_type`` is not a key of ``ATOMIC_RADII``.
    """
    # Resolve the table once instead of once per atom, and fail with a
    # message that names the valid choices.
    try:
        radii_table = ATOMIC_RADII[radius_type]
    except KeyError:
        raise KeyError(
            f"Unknown radius_type {radius_type!r}; "
            f"expected one of {sorted(ATOMIC_RADII)}"
        ) from None

    if isinstance(atomic_numbers, torch.Tensor):
        atomic_numbers = atomic_numbers.tolist()

    # Structures contain few distinct species, so translate each atomic
    # number to a symbol (and scaled radius) only the first time it is seen.
    scaled_radius_by_z = {}
    radii = []
    for z in atomic_numbers:
        if z not in scaled_radius_by_z:
            symbol = Element.from_Z(z).symbol
            scaled_radius_by_z[z] = radius_scale * radii_table.get(symbol, DEFAULT_RADIUS)
        radii.append(scaled_radius_by_z[z])
    return radii
def neighbor_list_and_relative_vec(
    pos,
    r_max,
    self_interaction=False,
    strict_self_interaction=True,
    cell=None,
    pbc=False,
):
    """Build a neighbor list (with periodic cell shifts) via ASE.

    Edges are given by the following convention:
    - ``edge_index[0]`` is the *source* (convolution center).
    - ``edge_index[1]`` is the *target* (neighbor).

    Args:
        pos (shape [N, 3]): Positional coordinates; Tensor or numpy array.
        r_max: Cutoff forwarded unchanged as ``cutoff`` to
            ``ase.neighborlist.primitive_neighbor_list``. ``generate_graph``
            in this module passes a list with one radius per atom.
        self_interaction (bool): When False, edges whose two endpoints are the
            same atom *and* whose cell shift is all zeros are removed.
        strict_self_interaction (bool): Forwarded as ASE's
            ``self_interaction`` argument.
        cell (shape [3, 3]): Cell for periodic boundary conditions; Tensor,
            array-like, or None (a zero cell is used).
        pbc (bool or 3-tuple of bool): Periodicity in each of the three
            dimensions; a single bool is applied to all three.

    Returns:
        edge_index (torch.Tensor [2, num_edges]): Long tensor of edges on the
            device of ``pos`` (CPU when ``pos`` is not a tensor).
        shifts (torch.Tensor [num_edges, 3]): Long tensor of cell shifts.
        cell_tensor (torch.Tensor [3, 3]): The input cell (before ASE
            completion) with the device and dtype of ``pos``.

    Raises:
        ValueError: If ``self_interaction`` is False and no edge survives the
            removal of same-image self-edges.
    """
    pbc_flags = (pbc,) * 3 if isinstance(pbc, bool) else pbc

    # Positions: ASE works on numpy arrays, so remember where to send results.
    if isinstance(pos, torch.Tensor):
        out_device, out_dtype = pos.device, pos.dtype
        pos_np = pos.detach().cpu().numpy()
    else:
        out_device, out_dtype = torch.device("cpu"), torch.get_default_dtype()
        pos_np = np.asarray(pos)

    if out_device.type != "cpu":
        warnings.warn(
            "Currently, neighborlists require a round trip to the CPU. Please pass CPU tensors if possible."
        )

    # Cell: keep a torch copy of the caller's cell for the return value.
    if cell is None:
        cell_np = np.zeros((3, 3), dtype=pos_np.dtype)
        cell_tensor = torch.as_tensor(cell_np, device=out_device, dtype=out_dtype)
    elif isinstance(cell, torch.Tensor):
        cell_np = cell.detach().cpu().numpy()
        cell_tensor = cell.to(device=out_device, dtype=out_dtype)
    else:
        cell_np = np.asarray(cell)
        cell_tensor = torch.as_tensor(cell_np, device=out_device, dtype=out_dtype)

    # ASE needs a full-rank cell even along non-periodic directions.
    cell_np = geometry.complete_cell(cell_np)

    first_index, second_index, shifts = neighborlist.primitive_neighbor_list(
        "ijS",
        pbc_flags,
        cell_np,
        pos_np,
        cutoff=r_max,
        self_interaction=strict_self_interaction,
        use_scaled_positions=False,
    )

    if not self_interaction:
        # An atom paired with itself inside the same periodic image.
        same_image_self = (first_index == second_index) & np.all(shifts == 0, axis=1)
        keep_edge = ~same_image_self
        if not np.any(keep_edge):
            raise ValueError("No edges remain after eliminating self-edges.")
        first_index = first_index[keep_edge]
        second_index = second_index[keep_edge]
        shifts = shifts[keep_edge]

    sources = torch.LongTensor(first_index)
    targets = torch.LongTensor(second_index)
    edge_index = torch.stack((sources, targets), dim=0).to(device=out_device)
    shifts = torch.as_tensor(shifts, dtype=torch.long, device=out_device)
    return edge_index, shifts, cell_tensor
def find_matching_columns_of_A_in_B(A, B):
    """
    Finds, for every column of A, the columns of B that are identical to it.

    Parameters:
    - A (torch.Tensor): First matrix.
    - B (torch.Tensor): Second matrix; must have the same number of rows as A
      and at least as many columns.

    Returns:
    - torch.Tensor: 1-D tensor of column indices into B, ordered by A column
      and then by ascending B column. Columns of A with no equal column in B
      contribute nothing, so the result can be shorter than ``A.shape[1]``.
    """
    assert A.shape[0] == B.shape[0], "The number of rows in A and B must be the same."
    assert A.shape[1] <= B.shape[1], "Please increase radius_scale factor!"

    num_rows = A.shape[0]
    num_cols_a = A.shape[1]
    num_cols_b = B.shape[1]

    # Small problems: compare every (A column, B column) pair in one broadcast.
    if num_cols_a * num_cols_b * num_rows < 10**6:
        pairwise_equal = (A.T.unsqueeze(1) == B.T.unsqueeze(0)).all(dim=-1)
        _, b_columns = pairwise_equal.nonzero(as_tuple=True)
        return b_columns

    # Large problems: pre-filter candidates on three probe rows (first,
    # middle, last) before comparing whole columns.
    probe_rows = [0, num_rows // 2, -1]
    A_probe = A[probe_rows, :]
    B_probe = B[probe_rows, :]

    found = []
    for col in range(num_cols_a):
        probe_equal = (A_probe[:, col:col + 1] == B_probe).all(dim=0)
        candidates = probe_equal.nonzero(as_tuple=True)[0]
        if candidates.numel() == 0:
            continue

        # Confirm candidates against the complete column.
        full_equal = (A[:, col:col + 1] == B[:, candidates]).all(dim=0)
        confirmed = candidates[full_equal]
        if confirmed.numel() > 0:
            found.append(confirmed)

    if not found:
        return torch.tensor([], dtype=torch.long, device=A.device)
    return torch.cat(found)
class BaseModel(nn.Module):
    """Base class for HamGNN models.

    Stores the radius table name and scale used for neighbor search, leaves
    ``forward`` to subclasses, and rebuilds a batched neighbor graph from the
    structures in a data batch via :meth:`generate_graph`.
    """

    def __init__(self, radius_type: str = 'openmx', radius_scale: float = 1.5) -> None:
        """Store the neighbor-search settings.

        Args:
            radius_type: Key of the radius table passed to
                ``get_radii_from_atomic_numbers``.
            radius_scale: Multiplier applied to the tabulated radii.
        """
        super().__init__()
        self.radius_type = radius_type
        self.radius_scale = radius_scale

    def forward(self, data):
        """Abstract; subclasses must implement the forward pass."""
        raise NotImplementedError

    def generate_graph(
        self,
        data,
    ):
        """Recompute the neighbor graph for every structure in ``data``.

        Reads ``data.batch``, ``data.cell``, ``data.pos``, ``data.z``,
        ``data.edge_index``, ``data.cell_shift`` and ``data.nbr_shift``.
        Nodes are assumed to be stored contiguously per structure (the node
        tensors are split by per-structure counts).

        Returns:
            EasyDict with keys ``z``, ``pos``, ``batch`` (taken from ``data``),
            ``edge_index``, ``cell_shift``, ``nbr_shift`` (recomputed, cast to
            the type of the corresponding ``data`` field) and
            ``matching_edges``: indices of the recomputed edges that equal the
            edges already present in ``data`` (same endpoints and cell shift).
        """
        graph = EasyDict()

        node_counts = scatter(torch.ones_like(data.batch), data.batch, dim=0).detach()
        # Index of each structure's first node in the concatenated batch.
        # This must be the cumulative count of all preceding structures, not
        # just the size of the previous one, or edges of the third and later
        # structures point at the wrong nodes.
        node_offsets = torch.cumsum(node_counts, dim=0) - node_counts
        counts = node_counts.tolist()

        latt_batch = data.cell.detach().reshape(-1, 3, 3)
        pos_batch = torch.split(data.pos.detach(), counts, dim=0)
        z_batch = torch.split(data.z.detach(), counts, dim=0)

        nbr_shift = []
        edge_index = []
        cell_shift = []
        for idx_xtal, pos in enumerate(pos_batch):
            lattice = latt_batch[idx_xtal]
            edge_index_temp, shifts_tmp, _ = neighbor_list_and_relative_vec(
                pos,
                r_max=get_radii_from_atomic_numbers(
                    z_batch[idx_xtal],
                    radius_scale=self.radius_scale,
                    radius_type=self.radius_type,
                ),
                self_interaction=False,
                strict_self_interaction=True,
                cell=lattice,
                pbc=True,
            )
            # Cartesian displacement contributed by the integer cell shifts.
            nbr_shift_temp = torch.einsum('ni, ij -> nj', shifts_tmp.type_as(pos), lattice)

            # Shift local node indices into the batch-wide numbering.
            edge_index.append(edge_index_temp + node_offsets[idx_xtal])
            cell_shift.append(shifts_tmp)
            nbr_shift.append(nbr_shift_temp)

        edge_index = torch.cat(edge_index, dim=-1).type_as(data.edge_index)
        cell_shift = torch.cat(cell_shift, dim=0).type_as(data.cell_shift)
        nbr_shift = torch.cat(nbr_shift, dim=0).type_as(data.nbr_shift)

        # Describe every edge as a column (source, target, shift_x, shift_y,
        # shift_z) and locate the pre-existing edges among the recomputed ones.
        matching_edges = find_matching_columns_of_A_in_B(
            torch.cat([data.edge_index, data.cell_shift.t()], dim=0),
            torch.cat([edge_index, cell_shift.t()], dim=0),
        )

        graph['z'] = data.z
        graph['pos'] = data.pos
        graph['edge_index'] = edge_index
        graph['cell_shift'] = cell_shift
        graph['nbr_shift'] = nbr_shift
        graph['batch'] = data.batch
        graph['matching_edges'] = matching_edges
        return graph

    @property
    def num_params(self) -> int:
        """Total number of elements across all registered parameters."""
        return sum(p.numel() for p in self.parameters())