Source code for hamgnn.utils.visualization


import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

def scatter_plot(pred: np.ndarray = None, target: np.ndarray = None):
    """Draw a prediction-versus-target scatter plot with a ``y = x`` guide line.

    Predictions are placed on the x axis and targets on the y axis. Both axes
    share the same scale, and a dashed red diagonal spanning the combined
    finite range of the data marks where prediction equals target.

    Args:
        pred: Predicted values (x axis). Must have the same number of
            elements as ``target``.
        target: Reference values (y axis).

    Returns:
        The newly created matplotlib figure containing the plot.

    Raises:
        TypeError: If ``pred`` or ``target`` is not supplied.
    """
    # The ``None`` defaults are kept for interface compatibility, but there is
    # nothing to plot without both inputs; fail early with a clear message
    # instead of an opaque error from inside matplotlib/numpy.
    if pred is None or target is None:
        raise TypeError("scatter_plot() requires both 'pred' and 'target'")

    fig, ax = plt.subplots()
    ax.scatter(x=pred, y=target, s=25, c='g', alpha=0.5, marker=".")
    ax.set_title('Prediction VS Target')
    ax.set_aspect('equal')

    # Compute the diagonal's extent from finite values only: a single NaN
    # would otherwise propagate through min/max and make the line disappear.
    values = np.concatenate([np.ravel(pred), np.ravel(target)]).astype(float)
    finite = values[np.isfinite(values)]
    if finite.size:  # nothing to span for empty / all-non-finite input
        min_val, max_val = finite.min(), finite.max()
        ax.plot([min_val, max_val], [min_val, max_val],
                ls="--", linewidth=1, c='r')

    # Style this figure's own axes rather than pyplot's implicit "current
    # axes", so the result does not depend on global pyplot state.
    ax.set_xlabel('Prediction', fontsize=15)
    ax.set_ylabel('Target', fontsize=15)
    ax.tick_params(axis='both', labelsize=15)
    return fig