eugene.train.fit
- eugene.train.fit(model, train_dataloader, val_dataloader=None, epochs=10, gpus=None, logger='tensorboard', log_dir=None, name=None, version=None, early_stopping_metric='val_loss_epoch', early_stopping_patience=5, early_stopping_verbose=False, model_checkpoint_k=1, model_checkpoint_monitor='val_loss_epoch', seed=None, return_trainer=False, **kwargs)
Fit a model using PyTorch Lightning.
This is a generic fit function that can be used to train any PyTorch Lightning LightningModule. All that’s required is a LightningModule, a training dataloader, and optionally a validation dataloader.
- Parameters:
model (LightningModule) – The model to train.
train_dataloader (DataLoader) – The training dataloader to use.
val_dataloader (DataLoader) – The validation dataloader to use.
epochs (int) – The number of epochs to train for.
gpus (int) – The number of GPUs to use. If not specified, EUGENe automatically uses all available GPUs.
logger (str or Logger) – The logger to use. If a string, must be one of “csv”, “tensorboard”, or “wandb”.
log_dir (PathLike) – The directory to save the logs to.
name (str) – The name of the experiment. Appended to log_dir, giving log_dir/name.
version (str) – The version of the experiment. Appended to log_dir/name, giving log_dir/name/version.
early_stopping_metric (str) – The metric to use for early stopping.
early_stopping_patience (int) – The number of epochs without improvement in early_stopping_metric to wait before stopping.
early_stopping_verbose (bool) – Whether to print early stopping messages.
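model_checkpoint_k (int) – The number of best model checkpoints to keep.
model_checkpoint_monitor (str) – The metric to monitor when selecting model checkpoints.
seed (int) – The seed to use for reproducibility.
return_trainer (bool) – Whether to return the Trainer object.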
kwargs (dict) – Additional keyword arguments to pass to the PyTorch Lightning Trainer.
- Returns:
trainer – The PyTorch Lightning Trainer object.
- Return type:
Trainer
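The small, self-contained sketch below illustrates what early_stopping_patience controls. The helper early_stopping_epoch is hypothetical and is not part of EUGENe. It makes two assumptions: the monitored metric (such as val_loss_epoch) should be minimized, and any decrease counts as an improvement. This is similar in spirit to PyTorch Lightning's EarlyStopping callback with mode="min" and min_delta=0. The actual stopping behavior is determined by whatever callback EUGENe configures.

```python
from typing import Optional, Sequence


def early_stopping_epoch(val_losses: Sequence[float], patience: int = 5) -> Optional[int]:
    """Hypothetical helper (not part of EUGENe).

    Returns the 0-based epoch at which training would stop, or None if the
    patience budget is never exhausted. Assumes the monitored metric is
    minimized and that any decrease counts as an improvement.
    """
    best = float("inf")
    wait = 0  # consecutive epochs without a new best value
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss  # new best value: reset the patience counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch  # patience exhausted: stop after this epoch
    return None


losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73]
print(early_stopping_epoch(losses, patience=3))  # 5
print(early_stopping_epoch(losses, patience=5))  # None
```

With patience=3, epochs 3, 4 and 5 each fail to beat the best loss of 0.7, so training stops after epoch 5. With patience=5, the run ends before the patience budget is used up.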
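Suppose model_checkpoint_k works like the save_top_k option of PyTorch Lightning's ModelCheckpoint callback. This is an assumption, because the signature does not say so. In that case, only the k checkpoints with the best model_checkpoint_monitor values are retained. The hypothetical helper top_k_checkpoints below sketches that selection for a metric where lower is better. It uses a heap bounded to k entries, so the worst retained checkpoint can be evicted in O(log k) time. Lightning additionally treats save_top_k=-1 as "keep every checkpoint", which this sketch does not model.

```python
import heapq
from typing import List, Sequence, Tuple


def top_k_checkpoints(val_losses: Sequence[float], k: int = 1) -> List[Tuple[float, int]]:
    """Hypothetical helper (not part of EUGENe).

    Returns (loss, epoch) pairs for the k lowest-loss epochs, best first,
    assuming model_checkpoint_k behaves like Lightning's save_top_k.
    """
    if k <= 0:
        return []
    # Heap keyed on the negated loss, so heap[0] is the worst checkpoint kept so far.
    heap: List[Tuple[float, int]] = []
    for epoch, loss in enumerate(val_losses):
        if len(heap) < k:
            heapq.heappush(heap, (-loss, epoch))
        elif loss < -heap[0][0]:
            # Better than the worst kept checkpoint: evict it and keep this one.
            heapq.heapreplace(heap, (-loss, epoch))
    return sorted((-neg_loss, epoch) for neg_loss, epoch in heap)


print(top_k_checkpoints([0.9, 0.5, 0.7, 0.4, 0.6], k=2))  # [(0.4, 3), (0.5, 1)]
```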