Source code for hamgnn.utils.triplets

import torch
from torch_sparse import SparseTensor

def triplets(edge_index, num_nodes, cell_shift):
    """Enumerate pairs of consecutive directed edges (k -> j -> i).

    For every edge ``j -> i`` in ``edge_index``, every edge ``k -> j`` that
    ends at ``j`` is paired with it. Pairs that walk straight back to the
    starting node (``i == k``) are discarded only when the two edges' cell
    shifts add up to zero in every component.

    Args:
        edge_index: Unpacked into ``(row, col)``; edge ``e`` is read as
            ``row[e] -> col[e]`` (i.e. ``j -> i``).
            Presumably an integer tensor of shape (2, num_edges) -- confirm
            against the caller.
        num_nodes: Size used for both dimensions of the sparse adjacency.
        cell_shift: Indexed by edge id and reduced over its last dimension.
            Presumably one integer shift vector per edge, shape
            (num_edges, 3) -- confirm against the caller.

    Returns:
        A 7-tuple ``(col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji)``:
        the per-edge target and source node indices, followed by the
        per-triplet node indices ``i``, ``j``, ``k`` and the per-triplet
        edge ids of ``k -> j`` and ``j -> i``.
    """
    src, dst = edge_index  # j->i

    # Sparse matrix whose entry (i, j) stores the id of edge j -> i, so that
    # row n lists every edge arriving at node n.
    edge_ids = torch.arange(src.size(0), device=src.device)
    incoming = SparseTensor(
        row=dst,
        col=src,
        value=edge_ids,
        sparse_sizes=(num_nodes, num_nodes),
    )

    # One row per edge j -> i, holding the edges k -> j that arrive at j.
    per_edge = incoming[src]
    counts = per_edge.set_value(None).sum(dim=1).to(torch.long)

    storage = per_edge.storage

    # Edge indices (k->j, j->i) for triplets.
    idx_ji = storage.row()
    idx_kj = storage.value()

    # Node indices (k->j->i) for triplets.
    idx_k = storage.col()
    idx_j = src.repeat_interleave(counts)
    idx_i = dst.repeat_interleave(counts)

    # Author's note on positions (``pos`` / ``nbr_shift`` are not defined in
    # this function):
    #   idx_i -> pos[idx_i]
    #   idx_j -> pos[idx_j] - nbr_shift[idx_ji]
    #   idx_k -> pos[idx_k] - nbr_shift[idx_ji] - nbr_shift[idx_kj]

    # Remove i == k triplets with the same cell_shift: keep a triplet when
    # its end nodes differ, or when the summed shift is non-zero somewhere.
    shift_sum = cell_shift[idx_kj] + cell_shift[idx_ji]
    nodes_differ = idx_i != idx_k
    shifted = torch.any(shift_sum != 0, dim=-1)
    keep = nodes_differ | shifted

    idx_i, idx_j, idx_k, idx_kj, idx_ji = [
        index[keep] for index in (idx_i, idx_j, idx_k, idx_kj, idx_ji)
    ]

    return dst, src, idx_i, idx_j, idx_k, idx_kj, idx_ji