eugene.train.fit

eugene.train.fit(model, train_dataloader, val_dataloader=None, epochs=10, gpus=None, logger='tensorboard', log_dir=None, name=None, version=None, early_stopping_metric='val_loss_epoch', early_stopping_patience=5, early_stopping_verbose=False, model_checkpoint_k=1, model_checkpoint_monitor='val_loss_epoch', seed=None, return_trainer=False, **kwargs)

Fit a model using PyTorch Lightning.

This is a generic fit function that can be used to train any PyTorch Lightning LightningModule. It requires only a LightningModule and a training dataloader; a validation dataloader is optional.

Parameters:
  • model (LightningModule) – The model to train.

  • train_dataloader (DataLoader) – The training dataloader to use.

  • val_dataloader (DataLoader) – The validation dataloader to use.

  • epochs (int) – The number of epochs to train for.

  • gpus (int) – The number of GPUs to use. By default, EUGENe uses all available GPUs.

  • logger (str or Logger) – The logger to use. If a string, must be one of “csv”, “tensorboard”, or “wandb”.

  • log_dir (PathLike) – The directory to save the logs to.

  • name (str) – The name of the experiment. Appended to the end of the log directory.

  • version (str) – The version of the experiment. Appended to the end of the log directory/name.

  • early_stopping_metric (str) – The metric to use for early stopping.

  • early_stopping_patience (int) – The number of epochs without improvement in the early stopping metric to wait before stopping.

  • early_stopping_verbose (bool) – Whether to print early stopping messages.

  • seed (int) – The seed to use for reproducibility.

  • kwargs (dict) – Additional keyword arguments to pass to the PL Trainer.
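
The effect of early_stopping_patience can be illustrated with a small pure-Python sketch. It is not EUGENe's or PyTorch Lightning's implementation. It assumes that the monitored metric is a loss where lower is better, that it is checked once per epoch, and that any strict decrease counts as an improvement. The helper name epochs_until_stop is hypothetical.

```python
def epochs_until_stop(val_losses, patience=5):
    """Return the 1-based epoch at which patience runs out, or None if it never does.

    Illustrative only: assumes lower is better and one check per epoch.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best value: remember it and reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            # No improvement this epoch: count it against the patience budget.
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation losses: the best value is reached at epoch 3,
# followed by five epochs without improvement.
losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.66, 0.68, 0.69, 0.50]
stop_epoch = epochs_until_stop(losses, patience=5)
print(stop_epoch)  # 8
```

In this made-up run the improvement at epoch 9 is never reached, which is the trade-off a small patience value makes.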

Returns:

trainer – The PyTorch Lightning Trainer object.

Return type:

Trainer
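
A call might look like the following sketch, which uses only arguments from the signature above. The ToyRegressor model, the random data, the log directory, and the experiment name are all hypothetical. The sketch assumes that eugene, torch, and pytorch_lightning are installed, that the function is importable as eugene.train.fit, and that return_trainer=True makes fit return the Trainer. The validation loss is logged as "val_loss_epoch" so that it matches the default early_stopping_metric and model_checkpoint_monitor.

```python
def train_toy_regressor(log_dir="eugene_logs"):
    """Train a tiny regressor with eugene.train.fit (illustrative sketch)."""
    # Heavy dependencies are imported inside the function so that merely
    # defining it does not require them to be installed.
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl
    from eugene import train

    class ToyRegressor(pl.LightningModule):  # hypothetical model
        def __init__(self):
            super().__init__()
            self.net = nn.Linear(8, 1)

        def forward(self, x):
            return self.net(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = nn.functional.mse_loss(self(x), y)
            self.log("train_loss", loss)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = nn.functional.mse_loss(self(x), y)
            # Logged under the metric name that fit monitors by default.
            self.log("val_loss_epoch", loss, on_step=False, on_epoch=True)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    # Random data standing in for a real dataset.
    x, y = torch.randn(256, 8), torch.randn(256, 1)
    train_dl = DataLoader(TensorDataset(x[:192], y[:192]), batch_size=32, shuffle=True)
    val_dl = DataLoader(TensorDataset(x[192:], y[192:]), batch_size=32)

    return train.fit(
        ToyRegressor(),
        train_dl,
        val_dataloader=val_dl,
        epochs=5,
        logger="csv",
        log_dir=log_dir,
        name="toy_regressor",
        version="v0",
        early_stopping_patience=3,
        seed=13,
        return_trainer=True,  # assumed to make fit return the Trainer
    )
```

Calling train_toy_regressor() starts training. Given the parameter descriptions above, the logs would be expected under eugene_logs/toy_regressor/v0.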