eugene.plot.metric_curve

eugene.plot.metric_curve(log_path, metric, hue='metric', title=None, xlab='minibatch step', ylab=None, ax=None, return_axes=False, **kwargs)

Plots metric curves (such as training and validation loss) from a PyTorch Lightning (PL) training run.

Reads the scalar values logged to the run's tensorboard event file, converts them to a pandas dataframe, and plots the result with seaborn.
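The pipeline described above can be sketched as follows. This is a minimal illustration, not eugene's implementation. It assumes the `tensorboard` package's `EventAccumulator` for reading the event file and `seaborn.lineplot` for plotting. The helpers `read_scalars`, `to_long_form`, and `plot_metric`, the substring match on tag names, and the `step`/`value`/`metric` column names are all hypothetical. A column named `metric` is chosen only because it matches the `hue='metric'` default in the signature.

```python
from collections import namedtuple
from typing import Dict, Iterable, List

# Stand-in with the same fields as tensorboard's ScalarEvent (wall_time, step, value).
ScalarEvent = namedtuple("ScalarEvent", ["wall_time", "step", "value"])


def read_scalars(log_path: str) -> Dict[str, list]:
    """Load every scalar series from a tensorboard log directory (requires `tensorboard`)."""
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    acc = EventAccumulator(log_path)
    acc.Reload()  # parse the event file(s) on disk
    return {tag: acc.Scalars(tag) for tag in acc.Tags()["scalars"]}


def to_long_form(scalars: Dict[str, Iterable], metric: str) -> List[dict]:
    """Flatten every series whose tag contains `metric` into long-form rows.

    Substring matching is an assumption, chosen so that "loss" selects
    tags such as "train_loss" and "val_loss" together.
    """
    rows = []
    for tag, events in scalars.items():
        if metric not in tag:
            continue  # skip unrelated series such as learning-rate logs
        for ev in events:
            rows.append({"step": ev.step, "value": ev.value, "metric": tag})
    return rows


def plot_metric(log_path: str, metric: str, ax=None, **kwargs):
    """Hypothetical end-to-end sketch: tensorboard -> pandas -> seaborn."""
    import pandas as pd
    import seaborn as sns

    df = pd.DataFrame(to_long_form(read_scalars(log_path), metric))
    # hue="metric" draws one colored line per matching tag
    return sns.lineplot(data=df, x="step", y="value", hue="metric", ax=ax, **kwargs)
```

Keeping the data in long form (one row per logged value, with the tag in its own column) lets seaborn split and color the curves through `hue` without any reshaping.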

Parameters:
  • log_path (str) – Path to the tensorboard log directory.

  • metric (str) – Name of the metric to plot, as it was logged in PL.

  • hue (str, optional) – Dataframe column that seaborn uses to color the curves. Defaults to 'metric'.

  • title (str, optional) – Title of the plot.

  • xlab (str, optional) – Label for the x-axis. Defaults to 'minibatch step'.

  • ylab (str, optional) – Label for the y-axis.

  • ax (matplotlib.pyplot.Axes, optional) – The axes object to plot on.

  • return_axes (bool, optional) – If True, returns the axes object.

  • **kwargs – Additional keyword arguments passed to the underlying seaborn plotting function.

Returns:

The axes object if return_axes is True; otherwise None.

Return type:

None or matplotlib.pyplot.Axes