import torch
import torch.nn as nn
import e3nn.o3 as o3
from e3nn.util.jit import compile_mode
from typing import Optional
from math import prod
from ..nn.interaction_blocks import ResidualBlock
from .Clebsch_Gordan_coefficients import ClebschGordanCoefficients
# [docs]
@compile_mode("script")
class HamLayer(nn.Module):
    """Residual block followed by an equivariant linear projection.

    The input features are first passed through a ``ResidualBlock`` and the
    result is then mapped by ``o3.Linear`` from ``irreps_in`` to
    ``irreps_out``.

    NOTE(review): the linear layer is declared with ``irreps_in`` as its
    input, so the residual block presumably returns features with the same
    irreps it receives -- confirm against ``ResidualBlock``.

    :param irreps_in: Irreps of the incoming features.
    :param feature_irreps_hidden: Hidden irreps forwarded to the residual block.
    :param irreps_out: Irreps produced by the final linear projection.
    :param nonlinearity_type: Nonlinearity selector forwarded to the residual block.
    :param resnet: Flag forwarded to the residual block.
    """

    def __init__(self, irreps_in, feature_irreps_hidden, irreps_out, nonlinearity_type: str = "gate", resnet: bool = True):
        super().__init__()

        # Feature refinement stage.
        self.residual_block = ResidualBlock(
            irreps_in=irreps_in,
            feature_irreps_hidden=feature_irreps_hidden,
            nonlinearity_type=nonlinearity_type,
            resnet=resnet,
        )

        # Projection stage onto the requested output irreps.
        self.linear_transform = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)

    def forward(self, x):
        """Run ``x`` through the residual block and then the linear projection."""
        refined = self.residual_block(x)
        return self.linear_transform(refined)
# [docs]
class TensorExpansion(nn.Module):
    """Expand orbital-pair matrices into a flat vector of irrep components.

    Every (row orbital shell, column orbital shell) sub-block of the input
    matrix is contracted with Clebsch-Gordan coefficients for each allowed
    total angular momentum ``L``.  The resulting pieces are concatenated in
    the order given by sorting the combined irreps (``irreps_out``).

    The orbital layout (``row`` / ``col`` irreps), the optional orbital
    permutation (``index_change``) and the optional list of orbitals whose
    sign must be flipped (``minus_index``) depend on ``ham_type`` and
    ``nao_max``.
    """

    def __init__(self, ham_type, nao_max):
        """
        :param ham_type: Type of Hamiltonian ('openmx', 'siesta', 'abacus', 'pasp')
        :param nao_max: Maximum number of atomic orbitals
        :raises NotImplementedError: if the (ham_type, nao_max) pair is unsupported.
        """
        super().__init__()
        self.ham_type = ham_type
        self.nao_max = nao_max

        # Filled in by _set_basis_info(); None means "not needed for this basis".
        self.index_change = None
        self.minus_index = None
        self.row = None
        self.col = None
        self._set_basis_info()

        # Calculate maximum l for Clebsch-Gordan coefficients
        max_l = self.row.lmax + self.col.lmax
        self.cg_calculator = ClebschGordanCoefficients(max_l=max_l)

        irreps_combined = self._combine_irreps()
        # sort() returns (sorted irreps, permutation, inverse permutation);
        # the inverse permutation lists, for each sorted position, the index
        # of the block in generation order.
        self.irreps_out, self.permute_indices, self.inverse_permute_indices = o3.Irreps(irreps_combined).sort()
        self.irreps_out = self.irreps_out.simplify()

    def _combine_irreps(self):
        """
        Combine input irreps to determine output irreps.

        For every pair of row/column shells, one irrep is emitted for each
        ``L`` in ``|li - lj| .. li + lj`` with parity ``(-1) ** (li + lj)``.
        The iteration order matches the block order produced in ``forward``.

        Returns:
            o3.Irreps containing the combined irreps (unsorted).
        """
        combined_irreps = []
        for _, li in self.row:
            for _, lj in self.col:
                parity = (-1) ** (li.l + lj.l)
                for L in range(abs(li.l - lj.l), li.l + lj.l + 1):
                    combined_irreps.append(o3.Irrep(L, parity))
        return o3.Irreps(combined_irreps)

    def _get_index_change_inv(self, index_change):
        """
        Get the inverse of an index change tensor.

        :param index_change: 1-D integer tensor holding a permutation.
        :return: Tensor ``inv`` such that ``inv[index_change[i]] == i``.
        """
        index_change_inv = torch.zeros_like(index_change)
        positions = torch.arange(
            index_change.numel(), dtype=index_change.dtype, device=index_change.device
        )
        # Single scatter instead of a Python loop over the elements.
        index_change_inv[index_change] = positions
        return index_change_inv

    def _set_basis_info(self):
        """
        Sets the basis information based on the Hamiltonian type and number of atomic orbitals.

        :raises NotImplementedError: if ``ham_type`` is unknown.
        """
        if self.ham_type == 'openmx':
            self._set_openmx_basis()
        elif self.ham_type == 'siesta':
            self._set_siesta_basis()
        elif self.ham_type == 'abacus':
            self._set_abacus_basis()
        elif self.ham_type == 'pasp':
            self.row = self.col = o3.Irreps("1x1o")
        else:
            raise NotImplementedError(f"Hamiltonian type '{self.ham_type}' is not supported.")

    def _set_openmx_basis(self):
        """
        Sets basis information for 'openmx' Hamiltonian.

        :raises NotImplementedError: if ``nao_max`` is not 13, 14, 19 or 26.
        """
        if self.nao_max == 14:
            self.index_change = torch.LongTensor([0, 1, 2, 5, 3, 4, 8, 6, 7, 11, 13, 9, 12, 10])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e")
        elif self.nao_max == 13:
            self.index_change = torch.LongTensor([0, 1, 4, 2, 3, 7, 5, 6, 10, 12, 8, 11, 9])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x1o+1x1o+1x2e")
        elif self.nao_max == 19:
            self.index_change = torch.LongTensor([0, 1, 2, 5, 3, 4, 8, 6, 7, 11, 13, 9, 12, 10, 16, 18, 14, 17, 15])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e")
        elif self.nao_max == 26:
            self.index_change = torch.LongTensor([0, 1, 2, 5, 3, 4, 8, 6, 7, 11, 13, 9, 12, 10, 16, 18, 14, 17, 15, 22, 23, 21, 24, 20, 25, 19])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e+1x3o")
        else:
            raise NotImplementedError(f"NAO max '{self.nao_max}' not supported for 'openmx'.")

    def _set_siesta_basis(self):
        """
        Sets basis information for 'siesta' Hamiltonian.

        :raises NotImplementedError: if ``nao_max`` is not 13 or 19.
        """
        if self.nao_max == 13:
            self.index_change = None
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x1o+1x1o+1x2e")
            self.minus_index = torch.LongTensor([2, 4, 5, 7, 9, 11])
        elif self.nao_max == 19:
            self.index_change = None
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e")
            self.minus_index = torch.LongTensor([3, 5, 6, 8, 10, 12, 15, 17])
        else:
            raise NotImplementedError(f"NAO max '{self.nao_max}' not supported for 'siesta'.")

    def _set_abacus_basis(self):
        """
        Sets basis information for 'abacus' Hamiltonian.

        :raises NotImplementedError: if ``nao_max`` is not 13, 27 or 40.
        """
        if self.nao_max == 13:
            self.index_change = torch.LongTensor([0, 1, 3, 4, 2, 6, 7, 5, 10, 11, 9, 12, 8])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x1o+1x1o+1x2e")
            self.minus_index = torch.LongTensor([3, 4, 6, 7, 9, 10])
        elif self.nao_max == 27:
            self.index_change = torch.LongTensor([0, 1, 2, 3, 5, 6, 4, 8, 9, 7, 12, 13, 11, 14, 10, 17, 18, 16, 19, 15, 23, 24, 22, 25, 21, 26, 20])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x0e+1x1o+1x1o+1x2e+1x2e+1x3o")
            self.minus_index = torch.LongTensor([5, 6, 8, 9, 11, 12, 16, 17, 21, 22, 25, 26])
        elif self.nao_max == 40:
            # NOTE(review): unlike the 13/27 cases no minus_index is set here,
            # so no sign flip is applied -- confirm this is intentional.
            self.index_change = torch.LongTensor([0, 1, 2, 3, 5, 6, 4, 8, 9, 7, 11, 12, 10, 14, 15, 13, 18, 19, 17, 20, 16, 23, 24, 22, 25, 21, 29, 30, 28, 31, 27, 32, 26, 36, 37, 35, 38, 34, 39, 33])
            self.row = self.col = o3.Irreps("1x0e+1x0e+1x0e+1x0e+1x1o+1x1o+1x1o+1x1o+1x2e+1x2e+1x3o+1x3o")
        else:
            raise NotImplementedError(f"NAO max '{self.nao_max}' not supported for 'abacus'.")

    def _flip_signs(self, hamiltonian):
        """
        Negate the rows and the columns listed in ``minus_index``.

        The flip is done out of place (multiplication by a +/-1 vector) so
        the tensor passed in by the caller is never modified.  Elements whose
        row and column are both listed are negated twice, i.e. left unchanged.

        :param hamiltonian: Tensor of shape (n, nao_max, nao_max).
        :return: New tensor of the same shape with the signs applied.
        """
        signs = hamiltonian.new_ones(self.nao_max)
        signs[self.minus_index.to(signs.device)] = -1
        return hamiltonian * signs[:, None] * signs[None, :]

    def _change_index(self, hamiltonian):
        """
        Apply the orbital permutation and then the sign flips.

        The input is reshaped to (-1, nao_max, nao_max); rows and columns are
        reordered with ``index_change`` (if set) and afterwards the orbitals
        in ``minus_index`` (if set) are negated.  This is the inverse of
        ``_change_index_inv``.  The input tensor is not modified.

        :param hamiltonian: Tensor reshapeable to (-1, nao_max, nao_max).
        :return: Tensor of shape (-1, nao_max, nao_max).
        """
        hamiltonian = hamiltonian.reshape(-1, self.nao_max, self.nao_max)
        if self.index_change is not None:
            hamiltonian = hamiltonian[:, self.index_change[:, None], self.index_change[None, :]]
        if self.minus_index is not None:
            hamiltonian = self._flip_signs(hamiltonian)
        return hamiltonian

    def _change_index_inv(self, hamiltonian):
        """
        Undo ``_change_index``: apply the sign flips, then the inverse permutation.

        The input is reshaped to (-1, nao_max, nao_max); the orbitals in
        ``minus_index`` (if set) are negated and afterwards rows and columns
        are reordered with the inverse of ``index_change`` (if set).  The
        input tensor is not modified.

        :param hamiltonian: Tensor reshapeable to (-1, nao_max, nao_max).
        :return: Tensor of shape (-1, nao_max, nao_max).
        """
        hamiltonian = hamiltonian.reshape(-1, self.nao_max, self.nao_max)
        if self.minus_index is not None:
            hamiltonian = self._flip_signs(hamiltonian)
        if self.index_change is not None:
            index_change_inv = self._get_index_change_inv(self.index_change)
            hamiltonian = hamiltonian[:, index_change_inv[:, None], index_change_inv[None, :]]
        return hamiltonian

    def forward(self, x):
        """
        Forward pass to compute the expanded tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (*, row.dim, col.dim).
                It is not modified.

        Returns:
            torch.Tensor: Expanded tensor; the last axis is the concatenation
            of all per-(shell pair, L) contractions in sorted-irrep order.
        """
        x = x.reshape(-1, self.row.dim, self.col.dim)
        x = self._change_index_inv(x)

        output_blocks = []
        row_start = 0
        for _, li in self.row:
            num_rows = 2 * li.l + 1
            row_band = x.narrow(-2, row_start, num_rows)
            col_start = 0
            for _, lj in self.col:
                num_cols = 2 * lj.l + 1
                # The sub-block does not depend on L, so slice it once.
                block = row_band.narrow(-1, col_start, num_cols)
                for L in range(abs(li.l - lj.l), li.l + lj.l + 1):
                    # Compute Clebsch-Gordan coefficients
                    cg_coeffs = self.cg_calculator(L, li.l, lj.l)
                    output_blocks.append(torch.einsum('nij, kij -> nk', block, cg_coeffs))
                col_start += num_cols
            row_start += num_rows

        # Concatenate the blocks in the order of the sorted irreps
        expanded_output = torch.cat([output_blocks[idx] for idx in self.inverse_permute_indices], dim=-1)
        return expanded_output
# [docs]
class OverlapExpand(nn.Module):
    """Attach irrep-expanded copies of the overlap matrices to a data object."""

    def __init__(self, ham_type, nao_max) -> None:
        """
        Build the underlying ``TensorExpansion`` and expose its output irreps.

        :param ham_type: Type of Hamiltonian ('openmx', 'siesta', 'abacus', 'pasp').
        :param nao_max: Maximum number of atomic orbitals.
        """
        super().__init__()
        expansion = TensorExpansion(ham_type=ham_type, nao_max=nao_max)
        self.tensor_expansion = expansion
        self.irreps_overlap = expansion.irreps_out

    def forward(self, data):
        """
        Expand the overlap tensors stored on ``data``.

        Args:
            data: Object exposing ``Son`` and ``Soff`` as attributes and
                supporting item assignment.

        Returns:
            The same ``data`` object, with ``'Son_expand'`` and
            ``'Soff_expand'`` set to the expanded tensors.
        """
        # (source attribute, destination key), processed in this order.
        for source_name, target_key in (("Son", "Son_expand"), ("Soff", "Soff_expand")):
            data[target_key] = self.tensor_expansion(getattr(data, source_name))
        return data
# [docs]
class TensorMerge(nn.Module):
    """Build a (irrep_out_1.dim x irrep_out_2.dim) matrix from irrep features.

    For every triple (input irrep, row irrep, column irrep) that is
    compatible according to ``get_expansion_path``, the input components are
    mixed by a weight tensor and coupled through a Wigner 3j symbol into the
    corresponding matrix block.  Blocks that receive no contribution are
    filled with zeros.

    :param irrep_in: Irreps of the input features.
    :param irrep_out_1: Irreps labelling the rows of the output matrix.
    :param irrep_out_2: Irreps labelling the columns of the output matrix.
    :param internal_weights: If true, path weights are a learnable parameter
        of this module; otherwise weights and biases are predicted from the
        input by an ``o3.Linear`` layer.
    """

    def __init__(self, irrep_in, irrep_out_1, irrep_out_2, internal_weights: Optional[bool] = False):
        super().__init__()
        self.irrep_in = irrep_in
        self.irrep_out_1 = irrep_out_1
        self.irrep_out_2 = irrep_out_2
        self.instructions = self.get_expansion_path(irrep_in, irrep_out_1, irrep_out_2)

        # One weight per (mul_in, mul_out1, mul_out2) entry of every weighted path.
        self.num_path_weight = sum(prod(ins[-1]) for ins in self.instructions if ins[3])
        # One bias per (mul_out1, mul_out2) entry of every path fed by input irrep 0.
        self.num_bias = sum(prod(ins[-1][1:]) for ins in self.instructions if ins[0] == 0)
        self.num_weights = self.num_path_weight + self.num_bias

        self.internal_weights = internal_weights
        if self.internal_weights:
            self.weights = nn.Parameter(torch.rand(self.num_path_weight + self.num_bias))
        else:
            self.linear_weight_bias = o3.Linear(self.irrep_in, o3.Irreps([(self.num_weights, (0, 1))]))

    def forward(self, x_in):
        """
        Merge irrep features into a flattened matrix.

        :param x_in: Tensor whose first axis is the batch and whose remaining
            entries are laid out according to ``irrep_in``.
        :return: Tensor of shape (batch, irrep_out_1.dim * irrep_out_2.dim).
        """
        if self.internal_weights:
            # Both stay None; the per-path weights come from self.weights.
            weights = bias_weights = None
        else:
            weights, bias_weights = torch.split(
                self.linear_weight_bias(x_in),
                split_size_or_sections=[self.num_path_weight, self.num_bias],
                dim=-1,
            )

        batch_num = x_in.shape[0]
        if len(self.irrep_in) == 1:
            x_in_s = [x_in.reshape(batch_num, self.irrep_in[0].mul, self.irrep_in[0].ir.dim)]
        else:
            x_in_s = [
                x_in[:, i].reshape(batch_num, mul_ir.mul, mul_ir.ir.dim)
                for i, mul_ir in zip(self.irrep_in.slices(), self.irrep_in)
            ]

        outputs = {}
        flat_weight_index = 0
        bias_weight_index = 0
        for ins in self.instructions:
            i_in, i_out1, i_out2, has_weight = ins[0], ins[1], ins[2], ins[3]
            path_shape = ins[-1]            # [mul_in, mul_out1, mul_out2]
            n_path_weights = prod(path_shape)

            mul_ir_in = self.irrep_in[i_in]
            mul_ir_out1 = self.irrep_out_1[i_out1]
            mul_ir_out2 = self.irrep_out_2[i_out2]

            x1 = x_in_s[i_in].reshape(batch_num, mul_ir_in.mul, mul_ir_in.ir.dim)
            w3j_matrix = o3.wigner_3j(
                mul_ir_out1.ir.l, mul_ir_out2.ir.l, mul_ir_in.ir.l).type_as(x_in)

            if has_weight is True or weights is not None:
                if weights is None:
                    # Shared (batch-independent) weights owned by the module.
                    weight = self.weights[
                        flat_weight_index:flat_weight_index + n_path_weights
                    ].reshape(path_shape)
                    result = torch.einsum(
                        "wuv, ijk, bwk-> buivj", weight, w3j_matrix, x1) / mul_ir_in.mul
                else:
                    # Per-sample weights predicted from the input.
                    weight = weights[
                        :, flat_weight_index:flat_weight_index + n_path_weights
                    ].reshape([-1] + path_shape)
                    result = torch.einsum("bwuv, bwk-> buvk", weight, x1)
                    if i_in == 0 and bias_weights is not None:
                        n_bias = prod(path_shape[1:])
                        bias_weight = bias_weights[
                            :, bias_weight_index:bias_weight_index + n_bias
                        ].reshape([-1] + path_shape[1:])
                        bias_weight_index += n_bias
                        result = result + bias_weight.unsqueeze(-1)
                    result = torch.einsum("ijk, buvk->buivj", w3j_matrix, result) / mul_ir_in.mul
                flat_weight_index += n_path_weights
            else:
                # Unweighted path: all-ones mixing tensor laid out like the
                # weighted case (mul_in, mul_out1, mul_out2) -> "wuv".
                result = torch.einsum(
                    "wuv, ijk, bwk-> buivj", x1.new_ones(path_shape), w3j_matrix, x1
                )

            result = result.reshape(batch_num, mul_ir_out1.dim, mul_ir_out2.dim)
            key = (i_out1, i_out2)
            if key in outputs:
                outputs[key] = outputs[key] + result
            else:
                outputs[key] = result

        # Assemble the block matrix row by row, zero-filling missing blocks.
        rows = []
        for i in range(len(self.irrep_out_1)):
            blocks = []
            for j in range(len(self.irrep_out_2)):
                if (i, j) in outputs:
                    blocks.append(outputs[(i, j)])
                else:
                    blocks.append(
                        x_in.new_zeros((batch_num, self.irrep_out_1[i].dim, self.irrep_out_2[j].dim))
                    )
            rows.append(torch.cat(blocks, dim=-1))
        output = torch.cat(rows, dim=-2).reshape(batch_num, -1)
        return output

    def get_expansion_path(self, irrep_in, irrep_out_1, irrep_out_2):
        """
        Enumerate all coupling paths between input and output irreps.

        A path ``[i, j, k, True, 1.0, [mul_in, mul_out1, mul_out2]]`` is
        created whenever input irrep ``i`` is contained in the product of
        row irrep ``j`` and column irrep ``k``.

        :return: List of instructions in (i, j, k) iteration order.
        """
        instructions = []
        for i, (num_in, ir_in) in enumerate(irrep_in):
            for j, (num_out1, ir_out1) in enumerate(irrep_out_1):
                for k, (num_out2, ir_out2) in enumerate(irrep_out_2):
                    if ir_in in ir_out1 * ir_out2:
                        instructions.append([i, j, k, True, 1.0, [num_in, num_out1, num_out2]])
        return instructions

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def __repr__(self):
        return f'{self.irrep_in} -> {self.irrep_out_1}x{self.irrep_out_2} and bias {self.num_bias} ' \
               f'with parameters {self.num_path_weight}'