# Source code for compiam.visualisation.training
import matplotlib.pyplot as plt
def plot_losses(train_loss, val_loss, output_path=None):
    """Plot training and validation loss curves on a single figure.

    Each curve is plotted against its index, and a legend labels them
    "train" and "val".

    :param train_loss: training loss curve (sequence of loss values)
    :param val_loss: validation loss curve (same length as training curve)
    :param output_path: optional path (finished with .png) where the plot is
        saved. If ``None`` or empty, the plot is shown interactively instead.
    """
    # Draw on a fresh figure so that artists left over from earlier pyplot
    # calls do not end up mixed into the loss plot.
    fig = plt.figure()
    plt.plot(train_loss, label="train")
    plt.plot(val_loss, label="val")
    plt.legend()
    if output_path:
        plt.savefig(output_path)
        # The figure has been written to disk and nothing else references it,
        # so close it rather than keeping it alive in pyplot's figure registry.
        plt.close(fig)
    else:
        plt.show()