Source code for hamgnn.utils.hparam


# Lower-cased ``config.setup.GNN_Net`` value -> attribute name of the matching
# parameter section under ``config.representation_nets``.
_HPARAM_SECTIONS = {
    'dimnet': 'dimnet_params',
    'edge_gnn': 'Edge_GNN',
    'schnet': 'SchNet',
    'cgcnn': 'cgcnn',
    'cgcnn_edge': 'cgcnn_edge',
    'painn': 'painn',
    'cgcnn_triplet': 'cgcnn_triplet',
    'dimenet_triplet': 'dimenet_triplet',
    'dimeham': 'dimeham',
    'dimeorb': 'dimeorb',
    'schnorb': 'schnorb',
    'nequip': 'nequip',
    'hamgnn_pre': 'HamGNN_pre',
}

# Exact value types that are passed through unchanged.  ``type(None)`` is used
# because ``type(x)`` is never ``None`` itself, so listing ``None`` here would
# never match anything.
_PASSTHROUGH_TYPES = (str, float, int, bool, type(None))


def get_hparam_dict(config: dict = None):
    """Return a flat hyper-parameter dict for the network selected in *config*.

    The network is chosen by ``config.setup.GNN_Net`` (case-insensitive) and
    its parameters are read from the matching attribute of
    ``config.representation_nets``.  Any name starting with ``hamgnn`` that is
    not listed explicitly falls back to the ``HamGNN_pre`` section.

    Values whose exact type is ``str``, ``float``, ``int``, ``bool`` or
    ``NoneType`` are copied as-is; every other value is replaced by the name
    of its type (e.g. a list becomes ``'list'``).  The section stored in
    *config* is left untouched.

    Args:
        config: Configuration object.  Despite the ``dict`` annotation it is
            accessed through attributes (``config.setup``,
            ``config.representation_nets``), so it must support attribute
            access.

    Returns:
        A new ``dict`` with the key ``'GNN_Name'`` (the network name exactly
        as written in the config) followed by the sanitised parameters.  A
        parameter that is itself called ``'GNN_Name'`` overrides that entry.

    Raises:
        SystemExit: If the requested network is not supported (a message is
            printed to stdout first).
    """
    gnn_name = config.setup.GNN_Net
    net_key = gnn_name.lower()

    section = _HPARAM_SECTIONS.get(net_key)
    if section is None and net_key.startswith('hamgnn'):
        # Every other "hamgnn*" variant shares the HamGNN_pre parameters.
        section = 'HamGNN_pre'
    if section is None:
        print(f"The network: {config.setup.GNN_Net} is not yet supported!")
        # ``quit()`` only exists when the ``site`` module is loaded; raising
        # SystemExit directly is the portable equivalent.
        raise SystemExit

    hparam_dict = getattr(config.representation_nets, section)

    out = {'GNN_Name': gnn_name}
    for key in hparam_dict:
        value = hparam_dict[key]
        if type(value) in _PASSTHROUGH_TYPES:
            out[key] = value
        else:
            # Non-primitive values are summarised by their type name; the
            # result goes into ``out`` only, so the caller's config keeps its
            # real values.
            out[key] = type(value).__name__.split(".")[-1]
    return out