eugene.train.fit_sequence_module

eugene.train.fit_sequence_module(model, sdata, seq_var=None, target_vars=None, in_memory=False, train_var='train_val', epochs=10, gpus=None, batch_size=None, num_workers=None, prefetch_factor=None, transforms=None, drop_last=False, logger='tensorboard', log_dir=None, name=None, version=None, early_stopping_metric='val_loss_epoch', early_stopping_patience=5, early_stopping_verbose=False, model_checkpoint_k=1, model_checkpoint_monitor='val_loss_epoch', seed=None, return_trainer=False, **kwargs)

Fit a SequenceModule using PyTorch Lightning. This function is a wrapper around the fit function that builds the training and validation dataloaders from a SeqData object.
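
As a rough illustration of the dataloader-building step, the following framework-free sketch splits examples on a train/validation flag and batches each split. It is not EUGENe's implementation: `make_loaders` is a hypothetical helper, and a plain dict of lists stands in for the SeqData object. It assumes that the `train_var` var holds booleans with `True` marking training examples, that only the training split is shuffled (using `seed`), that `drop_last` applies to both splits, and that each entry in `transforms` is applied to its var once per batch.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple

Batch = Dict[str, list]


def make_loaders(
    data: Dict[str, list],
    train_var: str = "train_val",
    batch_size: int = 2,
    drop_last: bool = False,
    transforms: Optional[Dict[str, Callable]] = None,
    seed: Optional[int] = None,
) -> Tuple[List[Batch], List[Batch]]:
    """Hypothetical sketch: split on a boolean train_var, then batch each split.

    Assumption: data[train_var][i] is True for training examples and False
    for validation examples; every other key is a per-example var.
    """
    flags = data[train_var]
    train_idx = [i for i, flag in enumerate(flags) if flag]
    val_idx = [i for i, flag in enumerate(flags) if not flag]
    # Assumption: only the training split is shuffled, seeded for reproducibility.
    random.Random(seed).shuffle(train_idx)

    def batches(indices: List[int]) -> List[Batch]:
        out = []
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            if drop_last and len(chunk) < batch_size:
                break  # discard the final, incomplete batch
            # Gather every var except the split flag itself.
            batch = {k: [v[i] for i in chunk] for k, v in data.items() if k != train_var}
            # Apply each per-var transform to that var's values in the batch.
            for var, fn in (transforms or {}).items():
                batch[var] = fn(batch[var])
            out.append(batch)
        return out

    return batches(train_idx), batches(val_idx)


# Toy usage with made-up sequences and targets.
data = {
    "seq": ["ACGT", "TTGA", "GGCC", "AATT", "CGCG"],
    "target": [0.1, 0.5, 0.9, 0.3, 0.7],
    "train_val": [True, True, True, False, False],
}
train_batches, val_batches = make_loaders(
    data,
    batch_size=2,
    drop_last=True,
    transforms={"seq": lambda seqs: [s.lower() for s in seqs]},
    seed=0,
)
print(len(train_batches), len(val_batches))  # 1 1
print(val_batches[0])  # {'seq': ['aatt', 'cgcg'], 'target': [0.3, 0.7]}
```

With `drop_last=False` the three training examples would instead yield two batches, of sizes 2 and 1.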

Parameters:
  • model (SequenceModule) – The model to train.

  • sdata (SeqData) – The SeqData object to train on.

  • seq_var (str) – The var in sdata that holds the input sequences.

  • target_vars (str or list of str) – The target vars in sdata to use as labels for training.

  • in_memory (bool) – Whether to load the data into memory before training. Default is False.

  • train_var (str) – The var in sdata used to split the data into training and validation sets.

  • epochs (int) – The number of epochs to train for.

  • gpus (int) – The number of GPUs to use. If not specified, EUGENe automatically uses all available GPUs.

  • batch_size (int) – The batch size to use.

  • num_workers (int) – The number of workers to use for the dataloader.

  • prefetch_factor (int) – The prefetch factor to use for the dataloader.

  • transforms (dict) – The transforms to apply to the data. This should be a dictionary of the form {“var”: transform function to apply}. See the documentation for SeqData for more information.

  • drop_last (bool) – Whether to drop the last batch if it is smaller than the batch size.

  • logger (str or Logger) – The logger to use. If a string, must be one of “csv”, “tensorboard”, or “wandb”.

  • log_dir (PathLike) – The directory to save the logs to.

  • name (str) – The name of the experiment.

  • version (str) – The version of the experiment.

  • early_stopping_metric (str) – The metric to use for early stopping.

  • early_stopping_patience (int) – The number of epochs without improvement in early_stopping_metric to wait before stopping training.

  • early_stopping_verbose (bool) – Whether to print early stopping messages.

  • model_checkpoint_k (int) – The number of best models to save, ranked by model_checkpoint_monitor.

  • model_checkpoint_monitor (str) – The metric to use for model checkpointing.

  • seed (int) – The seed to use for reproducibility.

  • return_trainer (bool) – Whether to return the trainer object.

  • kwargs (dict) – Additional keyword arguments to pass to the PyTorch Lightning Trainer.
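
The combined effect of early_stopping_patience and model_checkpoint_k can be approximated with a simplified, framework-free sketch. It is not Lightning's `EarlyStopping` or `ModelCheckpoint` (which add options such as `min_delta` and `mode`); the helper `simulate_training` and the loss values are hypothetical, and it assumes the monitored metric (for example `val_loss_epoch`) is minimized and checked once per epoch.

```python
import heapq
from typing import List, Tuple


def simulate_training(
    val_losses: List[float],
    patience: int = 5,
    checkpoint_k: int = 1,
) -> Tuple[int, List[Tuple[float, int]]]:
    """Hypothetical sketch of patience-based stopping plus top-k checkpointing.

    Returns (epochs_run, kept), where kept lists (loss, epoch) pairs, best first.
    Assumes lower loss is better.
    """
    best = float("inf")
    epochs_without_improvement = 0
    # Min-heap on negated loss, so kept[0] is the worst of the k kept checkpoints.
    kept: List[Tuple[float, int]] = []
    epochs_run = 0
    for epoch, loss in enumerate(val_losses):
        epochs_run = epoch + 1
        # Top-k checkpointing: keep this epoch if fewer than k are kept
        # or if it beats the worst checkpoint currently kept.
        if len(kept) < checkpoint_k:
            heapq.heappush(kept, (-loss, epoch))
        elif loss < -kept[0][0]:
            heapq.heapreplace(kept, (-loss, epoch))
        # Early stopping: count consecutive epochs without a new best loss.
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return epochs_run, sorted((-neg_loss, ep) for neg_loss, ep in kept)


# Toy usage with made-up per-epoch validation losses.
epochs_run, kept = simulate_training(
    [1.0, 0.8, 0.6, 0.65, 0.7, 0.9, 0.5], patience=3, checkpoint_k=2
)
print(epochs_run, kept)  # 6 [(0.6, 2), (0.65, 3)]
```

With a patience of 3, training stops after the sixth epoch, so the lower loss at the seventh epoch is never seen; a larger patience would trade extra training time for the chance of finding it.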

Returns:

trainer – The PyTorch Lightning Trainer object. Only returned if return_trainer is True.

Return type:

Trainer