import torch
import torch.nn as nn
from typing import Union
class cosine_similarity_loss(nn.Module):
    """Mean cosine distance ``1 - cos(pred, target)`` over the last dimension.

    The cosine similarity is computed along ``dim=-1`` and ``1 - cos`` is
    averaged over all remaining elements, giving a 0-d tensor.

    Args:
        eps: Lower bound applied to each vector norm before dividing, so a
            zero-length ``pred`` or ``target`` vector yields a finite loss
            (its cosine is treated as 0) instead of NaN.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return ``mean(1 - cos(pred, target))``.

        ``pred`` and ``target`` must be broadcast-compatible; the cosine is
        taken over their last dimension.
        """
        vec_product = torch.sum(pred * target, dim=-1)
        # Clamp each norm separately so only genuinely degenerate vectors are
        # affected; without the clamp a zero vector gives 0/0 = NaN, which
        # also poisons every gradient in the batch.
        pred_norm = torch.norm(pred, p=2, dim=-1).clamp_min(self.eps)
        target_norm = torch.norm(target, p=2, dim=-1).clamp_min(self.eps)
        cos = vec_product / (pred_norm * target_norm)
        return torch.mean(1.0 - cos)
class sum_zero_loss(nn.Module):
    """Penalty on the L2 norm of ``pred`` summed over its first dimension.

    The loss is zero exactly when ``pred`` sums to the zero vector along
    ``dim=0``. ``target`` is accepted only so the call signature matches the
    other losses; it is ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target=None):
        """Return ``||sum(pred, dim=0)||_2`` taken over the last dimension.

        For a 2-D ``pred`` of shape ``(N, D)`` the result is a 0-d tensor.
        """
        column_sum = torch.sum(pred, dim=0)
        # torch.norm defines a zero (sub)gradient at the origin, whereas
        # ``x.pow(2).sum().sqrt()`` back-propagates 0 * inf = NaN exactly when
        # the sum reaches zero -- i.e. at this loss's optimum.
        return torch.norm(column_sum, p=2, dim=-1)
class Euclidean_loss(nn.Module):
    """Mean Euclidean (L2) distance between ``pred`` and ``target``.

    The distance is taken over the last dimension and averaged over all
    remaining elements, giving a 0-d tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        # torch.norm has a zero (sub)gradient at the origin; the equivalent
        # ``.pow(2).sum(-1).sqrt()`` yields NaN gradients whenever a row of
        # ``pred`` matches ``target`` exactly.
        dist = torch.norm(pred - target, p=2, dim=-1)
        return torch.mean(dist)
class RMSELoss(nn.Module):
    """Root-mean-squared error: ``sqrt(nn.MSELoss()(pred, target))``."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mse = self.mse(pred, target)
        # d/dx sqrt(x) is infinite at 0, so a perfect prediction would
        # back-propagate NaN. Route exact zeros around the sqrt ("double
        # where"): forward values are unchanged (NaN still propagates because
        # NaN == 0 is False) and the gradient at zero becomes 0.
        is_zero = mse == 0
        safe_mse = torch.where(is_zero, torch.ones_like(mse), mse)
        return torch.where(is_zero, torch.zeros_like(mse), torch.sqrt(safe_mse))
def parse_metric_func(losses_list: Union[list, tuple, None] = None):
    """Replace metric names in ``losses_list`` with loss-module instances.

    Each element is expected to be a dict with a ``'metric'`` key. String
    values are matched case-insensitively and replaced in place by a freshly
    constructed module:

    ``'mse'`` -> ``nn.MSELoss``, ``'mae'`` -> ``nn.L1Loss``,
    ``'cosine_similarity'`` -> ``cosine_similarity_loss``,
    ``'sum_zero'`` -> ``sum_zero_loss``,
    ``'euclidean_loss'`` -> ``Euclidean_loss``, ``'rmse'`` -> ``RMSELoss``.

    Non-string values (e.g. a metric that was already resolved, or a custom
    callable) are left untouched, so calling this twice is harmless.
    Unsupported names are reported on stdout and left as-is.

    Args:
        losses_list: Sequence of loss-config dicts, mutated in place.
            ``None`` is treated as an empty list.

    Returns:
        The same ``losses_list`` object, or ``[]`` when it was ``None``.
    """
    if losses_list is None:
        return []
    for loss_dict in losses_list:
        metric = loss_dict['metric']
        if not isinstance(metric, str):
            # Already a module/callable: nothing to parse.
            continue
        name = metric.lower()
        if name == 'mse':
            loss_dict['metric'] = nn.MSELoss()
        elif name == 'mae':
            loss_dict['metric'] = nn.L1Loss()
        elif name == 'cosine_similarity':
            loss_dict['metric'] = cosine_similarity_loss()
        elif name == 'sum_zero':
            loss_dict['metric'] = sum_zero_loss()
        elif name == 'euclidean_loss':
            loss_dict['metric'] = Euclidean_loss()
        elif name == 'rmse':
            loss_dict['metric'] = RMSELoss()
        else:
            # Best-effort, as before: keep going, but name the offending
            # metric so the configuration error can be located.
            print(f'This metric function is not supported: {metric!r}')
    return losses_list