Source code for hamgnn.utils.losses

import torch
import torch.nn as nn
from typing import Union

class cosine_similarity_loss(nn.Module):
    """Mean cosine distance, ``1 - cos(pred, target)``, over the last dimension.

    Args:
        eps: Lower bound applied to each vector norm before dividing. It keeps
            the loss (and its gradient) finite when ``pred`` or ``target``
            contains an all-zero vector; such a pair contributes a cosine of 0.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return the mean of ``1 - cosine similarity`` between ``pred`` and ``target``.

        The dot product and the L2 norms are taken along ``dim=-1``; the
        resulting per-vector losses are averaged into a single scalar tensor.
        """
        dot = torch.sum(pred * target, dim=-1)
        # Clamp the norms separately so a zero vector cannot produce 0/0 = NaN.
        pred_norm = torch.norm(pred, p=2, dim=-1).clamp_min(self.eps)
        target_norm = torch.norm(target, p=2, dim=-1).clamp_min(self.eps)
        cosine = dot / (pred_norm * target_norm)
        return torch.mean(1.0 - cosine)
class sum_zero_loss(nn.Module):
    """Penalise a non-zero sum of the predictions along dimension 0."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the L2 norm (over ``dim=-1``) of ``pred`` summed over ``dim=0``.

        ``target`` is accepted only so the signature matches the other loss
        modules; it is not used in the computation.
        """
        total = torch.sum(pred, dim=0)
        # torch.norm instead of pow(2).sum().sqrt(): the explicit sqrt has an
        # infinite derivative at 0, which turns the gradient into NaN exactly
        # when the summed vector is zero (i.e. at the optimum of this loss).
        return torch.norm(total, p=2, dim=-1)
class Euclidean_loss(nn.Module):
    """Mean Euclidean (L2) distance between predicted and target vectors."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the mean over all vectors of ``||pred - target||_2``.

        The norm is taken along ``dim=-1`` and the resulting distances are
        averaged into a single scalar tensor.
        """
        # torch.norm instead of pow(2).sum().sqrt(): the explicit sqrt yields
        # NaN gradients for every vector where pred equals target exactly.
        dist = torch.norm(pred - target, p=2, dim=-1)
        return torch.mean(dist)
class RMSELoss(nn.Module):
    """Root-mean-square error built on top of ``nn.MSELoss``."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        """Return ``sqrt(MSE(pred, target))``.

        The forward value is unchanged from a plain ``torch.sqrt`` of the MSE;
        only the zero-error case is handled specially so that the backward
        pass stays finite.
        """
        mse = self.mse(pred, target)
        # sqrt has an infinite derivative at 0, so a perfect prediction would
        # back-propagate NaN. Route the zero case around the sqrt; comparing
        # with == 0 (rather than > 0) lets a NaN MSE still propagate as NaN.
        is_zero = mse == 0
        safe_mse = torch.where(is_zero, torch.ones_like(mse), mse)
        return torch.where(is_zero, torch.zeros_like(mse), torch.sqrt(safe_mse))
def parse_metric_func(losses_list: Union[list, tuple, None] = None):
    """Replace metric names in ``losses_list`` with instantiated loss modules.

    Each element of ``losses_list`` is a dict with a ``'metric'`` entry. When
    that entry is a string it is matched case-insensitively against the
    supported names (``mse``, ``mae``, ``cosine_similarity``, ``sum_zero``,
    ``euclidean_loss``, ``rmse``) and overwritten in place with a new instance
    of the corresponding loss module.

    Args:
        losses_list: Sequence of loss-configuration dicts, or ``None``.

    Returns:
        The same ``losses_list`` object after in-place substitution, or an
        empty list when ``losses_list`` is ``None``.
    """
    if losses_list is None:
        # The default argument is None; iterating it would raise TypeError.
        return []

    metric_factories = {
        'mse': nn.MSELoss,
        'mae': nn.L1Loss,
        'cosine_similarity': cosine_similarity_loss,
        'sum_zero': sum_zero_loss,
        'euclidean_loss': Euclidean_loss,
        'rmse': RMSELoss,
    }

    for loss_dict in losses_list:
        metric = loss_dict['metric']
        if not isinstance(metric, str):
            # Already resolved (e.g. this function ran on the list before):
            # leave the existing object untouched instead of calling .lower().
            continue
        factory = metric_factories.get(metric.lower())
        if factory is None:
            # Entry is left as-is, matching the previous best-effort behaviour,
            # but the message now names the offending metric.
            print(f'This metric function is not supported: {metric!r}')
        else:
            loss_dict['metric'] = factory()
    return losses_list