eugene.plot.metric_curve
- eugene.plot.metric_curve(log_path, metric, hue='metric', title=None, xlab='minibatch step', ylab=None, ax=None, return_axes=False, **kwargs)
Plots metric curves (such as the training loss) from a PyTorch Lightning (PL) training run.
Reads the TensorBoard event file in log_path, extracts the logged values of metric into a pandas DataFrame, and plots that DataFrame with seaborn.
- Parameters:
log_path (str) – Path to the TensorBoard log directory.
metric (str) – Metric to plot. Must be the name under which the metric was logged in PL (for example, the name passed to self.log).
title (str, optional) – Title of the plot.
xlab (str, optional) – Label for the x-axis. Defaults to 'minibatch step'.
ylab (str, optional) – Label for the y-axis.
ax (matplotlib.pyplot.Axes, optional) – The axes object to plot on.
return_axes (bool, optional) – If True, returns the axes object.
**kwargs – Additional keyword arguments to pass to seaborn.
- Returns:
If return_axes is True, returns the axes object; otherwise returns None.
- Return type:
None or matplotlib.pyplot.Axes