Source code for hamgnn.config.config_parsing

# Copyright (c) 2021-2026 HamGNN Team
# SPDX-License-Identifier: GPL-3.0-only

"""Default configuration templates and YAML merge loader for HamGNN.

Defines ``config_default`` and nested section dicts, and provides :func:`load_config`
to merge user YAML with defaults into :class:`easydict.EasyDict`, including parsing
of loss/metric function specs.
"""

import yaml
import argparse
import copy
from typing import Dict, Any, Optional, Union
from easydict import EasyDict

from ..utils.losses import parse_metric_func

"""
Default configuration parameters.
"""
config_default = dict()

"""The parameters for setup"""
config_default_setup = dict()
config_default_setup['GNN_Net'] = 'HamGNNpre'
config_default_setup['ignore_warnings'] = True
config_default_setup['checkpoint_path'] = './'
config_default_setup['load_from_checkpoint'] = False
config_default_setup['resume'] = False
config_default_setup['num_gpus'] = 1
config_default_setup['hostname'] = 'host'
config_default_setup['job_id'] = 'time_2025'
config_default_setup['precision'] = 32
config_default_setup['property'] = 'hamiltonian'
config_default_setup['stage'] = 'fit'
config_default_setup['use_gradient_checkpointing'] = False
config_default['setup'] = config_default_setup

"""The parameters for profiler"""
config_default_profiler = dict()
config_default_profiler['train_dir'] = './'
config_default_profiler['progress_bar_refresh_rat'] = 1
config_default['profiler_params'] = config_default_profiler

"""The parameters for representation_nets"""
config_default_representation_nets = dict()
config_default_representation_nets['HamGNN_pre'] = dict()
config_default_representation_nets['HamGNN_pre']['cutoff'] = 26.0
config_default_representation_nets['HamGNN_pre']['cutoff_func'] = 'cos'
config_default_representation_nets['HamGNN_pre']['radius_type'] = 'openmx'
config_default_representation_nets['HamGNN_pre']['edge_sh_normalization'] = 'component'
config_default_representation_nets['HamGNN_pre']['edge_sh_normalize'] = True
config_default_representation_nets['HamGNN_pre']['irreps_edge_sh'] = '0e + 1o + 2e + 3o + 4e + 5o'
config_default_representation_nets['HamGNN_pre']['irreps_node_features'] = '64x0e+64x0o+32x1o+16x1e+12x2o+25x2e+18x3o+9x3e+4x4o+9x4e+4x5o+4x5e+2x6e'
config_default_representation_nets['HamGNN_pre']['num_layers'] = 3
config_default_representation_nets['HamGNN_pre']['num_radial'] = 64
config_default_representation_nets['HamGNN_pre']['num_types'] = 96
config_default_representation_nets['HamGNN_pre']['rbf_func'] = 'bessel'
config_default_representation_nets['HamGNN_pre']['set_features'] = True
config_default_representation_nets['HamGNN_pre']['radial_MLP'] = [64, 64]
config_default_representation_nets['HamGNN_pre']['use_corr_prod'] = False
config_default_representation_nets['HamGNN_pre']['correlation'] = 2
config_default_representation_nets['HamGNN_pre']['num_hidden_features'] = 16
config_default_representation_nets['HamGNN_pre']['use_kan'] = False
config_default_representation_nets['HamGNN_pre']['radius_scale'] = 1.01
config_default_representation_nets['HamGNN_pre']['build_internal_graph'] = False
config_default_representation_nets['HamGNN_pre']['use_gradient_checkpointing'] = False
config_default['representation_nets'] = config_default_representation_nets

"""The parameters for output_nets"""
config_default_output_nets = dict()
config_default_output_nets['output_module'] = 'HamGNN_out'
config_default_output_nets['HamGNN_out'] = dict()
config_default_output_nets['HamGNN_out']['ham_only'] = True
config_default_output_nets['HamGNN_out']['ham_type'] = 'openmx'
config_default_output_nets['HamGNN_out']['nao_max'] = 26
config_default_output_nets['HamGNN_out']['add_H0'] = True
config_default_output_nets['HamGNN_out']['add_H_nonsoc'] = False
config_default_output_nets['HamGNN_out']['symmetrize'] = True
config_default_output_nets['HamGNN_out']['calculate_band_energy'] = False
config_default_output_nets['HamGNN_out']['num_k'] = 5
config_default_output_nets['HamGNN_out']['band_num_control'] = 8
config_default_output_nets['HamGNN_out']['k_path'] = None
config_default_output_nets['HamGNN_out']['soc_switch'] = False
config_default_output_nets['HamGNN_out']['nonlinearity_type'] = 'gate'
config_default_output_nets['HamGNN_out']['spin_constrained'] = False
config_default_output_nets['HamGNN_out']['collinear_spin'] = False
config_default_output_nets['HamGNN_out']['minMagneticMoment'] = 0.5
config_default_output_nets['HamGNN_out']['zero_point_shift'] = True
config_default_output_nets['HamGNN_out']['get_nonzero_mask_tensor'] = False
config_default['output_nets'] = config_default_output_nets

"""The parameters for optimizer."""
config_default_optimizer = dict()
config_default_optimizer['lr'] = 0.01
config_default_optimizer['lr_decay'] = 0.5
config_default_optimizer['lr_patience'] = 5
config_default_optimizer['gradient_clip_val'] = 0.0
config_default_optimizer['stop_patience'] = 30
config_default_optimizer['min_epochs'] = 100
config_default_optimizer['max_epochs'] = 3000
config_default['optim_params'] = config_default_optimizer

"""The parameters for losses_metrics."""
config_default_metric = dict()
config_default_metric['losses'] = [{'metric': 'mae', 'prediction': 'hamiltonian', 'target': 'hamiltonian', 'loss_weight': 27.211}]
config_default_metric['metrics'] = [{'metric': 'mae', 'prediction': 'hamiltonian', 'target': 'hamiltonian'}]
config_default['losses_metrics'] = config_default_metric

"""The parameters for dataset."""
config_default_dataset = dict()
config_default_dataset['batch_size'] = 1
config_default_dataset['split_file'] = None
config_default_dataset['test_ratio'] = 0.2
config_default_dataset['train_ratio'] = 0.6
config_default_dataset['val_ratio'] = 0.2
config_default_dataset['graph_data_path'] = './'
config_default['dataset_params'] = config_default_dataset


[docs] def recursive_update(base_dict: Dict[str, Any], update_dict: Dict[str, Any]) -> Dict[str, Any]: """ Recursively update a dictionary with values from another dictionary. Args: base_dict: The dictionary to be updated (typically default configuration) update_dict: The dictionary containing updates (typically from user YAML file) Returns: Dict[str, Any]: The updated dictionary Notes: - If a key exists in both dictionaries and both values are dictionaries, recursively merge these nested dictionaries - If a key exists in both dictionaries but values aren't both dictionaries, the value from update_dict overrides the value in base_dict - If a key exists only in update_dict, it's added to base_dict """ for key, value in update_dict.items(): if isinstance(value, dict) and key in base_dict and isinstance(base_dict[key], dict): # If both values are dictionaries, recursively update base_dict[key] = recursive_update(base_dict[key], value) else: # Otherwise, directly update the value base_dict[key] = value return base_dict
[docs] def load_config(config_file_path: Optional[str] = None) -> EasyDict: """ Load configuration from a YAML file and merge it with default configuration. This function reads a YAML configuration file and recursively merges its contents with the default configuration. The config file path can be provided either as a function parameter or as a command-line argument. Args: config_file_path: Path to the YAML configuration file. If None, attempts to get the path from command-line arguments. If not provided via command line either, uses 'config_default.yaml'. Returns: EasyDict: An EasyDict object containing the merged configuration. Raises: FileNotFoundError: If the configuration file doesn't exist. yaml.YAMLError: If the YAML file has parsing errors. UnicodeDecodeError: If there are encoding issues when reading the file. Exception: For any other unexpected errors. """ # Check if config file path is provided via command line if not given as parameter if not config_file_path: parser = argparse.ArgumentParser(description='Load configuration from a YAML file.') parser.add_argument('--config', '-c', type=str, default='config_default.yaml', help='Path to the YAML configuration file.') args, _ = parser.parse_known_args() config_file_path = args.config # Create a deep copy of config_default to avoid modifying the original config_copy = copy.deepcopy(config_default) # Try to read and merge configuration from the file try: with open(config_file_path, encoding='utf-8') as config_file: user_config = yaml.safe_load(config_file) # If file is empty or invalid, use an empty dict if user_config is None: user_config = {} # Recursively update each configuration section for section_key in user_config: if (section_key in config_copy and isinstance(config_copy[section_key], dict) and isinstance(user_config[section_key], dict)): config_copy[section_key] = recursive_update(config_copy[section_key], user_config[section_key]) else: config_copy[section_key] = user_config[section_key] except FileNotFoundError: raise FileNotFoundError(f"Configuration file not found: {config_file_path}") except yaml.YAMLError as e: raise yaml.YAMLError(f"Error parsing YAML file {config_file_path}: {e}") except UnicodeDecodeError as e: raise UnicodeDecodeError(f"Encoding error when reading {config_file_path}: {e}") except Exception as e: raise Exception(f"Unexpected error when reading {config_file_path}: {e}") # Convert to EasyDict config = EasyDict(config_copy) # Process specific fields if they exist if hasattr(config, 'losses_metrics'): if hasattr(config.losses_metrics, 'losses'): config.losses_metrics.losses = parse_metric_func(config.losses_metrics.losses) if hasattr(config.losses_metrics, 'metrics'): config.losses_metrics.metrics = parse_metric_func(config.losses_metrics.metrics) return config