eugene.plot.loss_curve

eugene.plot.loss_curve(log_path, title=None, xlab='minibatch_step', ylab='loss', ax=None, return_axes=False, **kwargs)

Plots the loss curves from a PyTorch Lightning (PL) training run. Wraps the metrics_curve function.

Reads the TensorBoard event file in the log directory, converts the logged metric curves into a pandas DataFrame, and plots that DataFrame with seaborn.
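The extraction step described above can be sketched roughly as follows. This is a minimal illustration, not EUGENe's implementation. It assumes the `tensorboard` package's `EventAccumulator` is used to read scalars, and the helper names `read_scalars` and `scalars_to_frame` and the DataFrame column names are hypothetical.

```python
from collections import namedtuple

import pandas as pd


def read_scalars(log_path):
    """Hypothetical helper: load every scalar series logged in a TensorBoard directory.

    Assumes the `tensorboard` package is installed. It is imported lazily so the
    DataFrame step below can be used on its own.
    """
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    acc = EventAccumulator(log_path)
    acc.Reload()  # parse the event file(s) found under log_path
    # Each ScalarEvent carries wall_time, step and value fields.
    return {tag: acc.Scalars(tag) for tag in acc.Tags()["scalars"]}


def scalars_to_frame(scalars_by_tag):
    """Hypothetical helper: flatten {tag: [events with .step/.value]} into a long-format DataFrame."""
    rows = [
        {"minibatch_step": ev.step, "metric": tag, "value": ev.value}
        for tag, events in scalars_by_tag.items()
        for ev in events
    ]
    return pd.DataFrame(rows, columns=["minibatch_step", "metric", "value"])


# Stand-in events with the same .step/.value fields as TensorBoard's ScalarEvent.
Event = namedtuple("Event", ["step", "value"])
df = scalars_to_frame({
    "train_loss": [Event(0, 1.2), Event(1, 0.9)],
    "val_loss": [Event(1, 1.0)],
})
```

A long ("tidy") layout with one row per logged point keeps every metric in a single frame, which lets seaborn draw one line per metric through its `hue` argument.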

Parameters:
  • log_path (str) – Path to the TensorBoard log directory.

  • title (str, optional) – Title of the plot.

  • xlab (str, optional) – Label for the x-axis. Defaults to 'minibatch_step'.

  • ylab (str, optional) – Label for the y-axis. Defaults to 'loss'.

  • ax (matplotlib.pyplot.Axes, optional) – The axes object to plot on.

  • return_axes (bool, optional) – If True, returns the axes object.

  • **kwargs – Additional keyword arguments to pass to seaborn.

Returns:

The axes object if return_axes is True; otherwise None.

Return type:

None or matplotlib.pyplot.Axes