eugene.train.fit_sequence_module
- eugene.train.fit_sequence_module(model, sdata, seq_var=None, target_vars=None, in_memory=False, train_var='train_val', epochs=10, gpus=None, batch_size=None, num_workers=None, prefetch_factor=None, transforms=None, drop_last=False, logger='tensorboard', log_dir=None, name=None, version=None, early_stopping_metric='val_loss_epoch', early_stopping_patience=5, early_stopping_verbose=False, model_checkpoint_k=1, model_checkpoint_monitor='val_loss_epoch', seed=None, return_trainer=False, **kwargs)
Fit a SequenceModule using PyTorch Lightning. This function is a wrapper around the fit function that builds the training and validation dataloaders from a SeqData object.
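To make the dataloader-building step concrete, here is a small pure-Python sketch of splitting a dataset on train_var and grouping rows into batches. It is not EUGENe code: a plain dict of lists stands in for SeqData, the helpers split_by_train_var and batch_indices are hypothetical, and it assumes that True in the train_var column marks a training sequence and False a validation one. The last call shows the effect of drop_last on a short final batch.

```python
# Hypothetical sketch, not EUGENe's implementation.
# Assumptions: the dataset is a dict of equal-length lists (standing in for
# SeqData) and the train_var column holds booleans, True = training row.
from typing import Dict, List, Tuple


def split_by_train_var(
    data: Dict[str, list], train_var: str = "train_val"
) -> Tuple[Dict[str, list], Dict[str, list]]:
    """Split a dict-of-columns dataset into train and validation dicts."""
    mask = data[train_var]
    train = {k: [v for v, m in zip(col, mask) if m] for k, col in data.items()}
    val = {k: [v for v, m in zip(col, mask) if not m] for k, col in data.items()}
    return train, val


def batch_indices(n: int, batch_size: int, drop_last: bool = False) -> List[List[int]]:
    """Group row indices into batches, optionally dropping a short final batch."""
    batches = [list(range(i, min(i + batch_size, n))) for i in range(0, n, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete last batch
    return batches


data = {
    "seq": ["ACGT", "TTGA", "GGCC", "ATAT", "CCGG"],
    "target": [0.1, 0.7, 0.3, 0.9, 0.5],
    "train_val": [True, True, False, True, True],
}
train, val = split_by_train_var(data)
print(len(train["seq"]), len(val["seq"]))  # 4 1
print(batch_indices(len(train["seq"]), batch_size=3))  # [[0, 1, 2], [3]]
print(batch_indices(len(train["seq"]), batch_size=3, drop_last=True))  # [[0, 1, 2]]
```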
- Parameters:
model (SequenceModule) – The model to train.
sdata (SeqData) – The SeqData object to train on.
seq_var (str) – The var in sdata containing the input sequences.
target_vars (str or list of str) – The target vars in sdata to use as labels for training.
in_memory (bool) – Whether to load the data into memory before training. Default is False.
train_var (str) – The var in sdata used to split the data into training and validation sets.
epochs (int) – The number of epochs to train for.
gpus (int) – The number of GPUs to use. If not specified, EUGENe automatically uses all available GPUs.
batch_size (int) – The batch size to use.
num_workers (int) – The number of workers to use for the dataloader.
prefetch_factor (int) – The prefetch factor to use for the dataloader.
transforms (dict) – The transforms to apply to the data. This should be a dictionary of the form {“var”: transform function to apply}. See the documentation for SeqData for more information.
drop_last (bool) – Whether to drop the last batch if it is smaller than the batch size.
logger (str or Logger) – The logger to use. If a string, must be one of “csv”, “tensorboard”, or “wandb”.
log_dir (PathLike) – The directory to save the logs to.
name (str) – The name of the experiment.
version (str) – The version of the experiment.
early_stopping_metric (str) – The metric to use for early stopping.
early_stopping_patience (int) – The number of epochs without improvement in early_stopping_metric to wait before stopping training.
early_stopping_verbose (bool) – Whether to print early stopping messages.
model_checkpoint_k (int) – The number of best model checkpoints to keep, ranked by model_checkpoint_monitor.
model_checkpoint_monitor (str) – The metric to use for model checkpointing.
seed (int) – The seed to use for reproducibility.
return_trainer (bool) – Whether to return the trainer object.
kwargs (dict) – Additional keyword arguments to pass to the PyTorch Lightning Trainer.
- Returns:
trainer – The PyTorch Lightning Trainer object. Returned only if return_trainer is True.
- Return type:
Trainer
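The transforms argument maps var names to functions. The sketch below illustrates that dictionary shape under stated assumptions: one sample is a plain dict keyed by var name, one_hot and apply_transforms are hypothetical helpers, and vars without an entry pass through unchanged. How SeqData actually applies transforms (per sample or per batch) may differ; see the SeqData documentation.

```python
# Hypothetical sketch of a {"var": transform function} dict; not the SeqData API.
from typing import Callable, Dict, List

ALPHABET = "ACGT"


def one_hot(seq: str) -> List[List[int]]:
    """Encode a DNA string as one-hot rows in A, C, G, T column order."""
    return [[1 if base == letter else 0 for letter in ALPHABET] for base in seq]


def apply_transforms(sample: Dict[str, object], transforms: Dict[str, Callable]) -> Dict[str, object]:
    """Apply each transform to its var; vars without a transform pass through."""
    return {
        var: transforms[var](value) if var in transforms else value
        for var, value in sample.items()
    }


transforms = {"seq": one_hot, "target": float}
example = {"seq": "ACG", "target": "0.5", "id": "seq_1"}
print(apply_transforms(example, transforms))
# {'seq': [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]], 'target': 0.5, 'id': 'seq_1'}
```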
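early_stopping_metric and early_stopping_patience follow the usual patience rule: training stops once the monitored metric has gone a given number of consecutive epochs without improving. The following sketch reproduces that rule in plain Python, assuming lower is better (as for val_loss_epoch) and that any strict decrease counts as an improvement; early_stop_epoch is a hypothetical helper, not the PyTorch Lightning EarlyStopping callback itself.

```python
# Hypothetical sketch of patience-based early stopping.
# Assumptions: lower metric is better; any strict decrease is an improvement.
import math
from typing import List, Optional


def early_stop_epoch(metric_per_epoch: List[float], patience: int = 5) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None."""
    best = math.inf
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(metric_per_epoch):
        if value < best:
            best = value
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


val_loss = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.74, 0.73, 0.69]
print(early_stop_epoch(val_loss, patience=5))   # 7
print(early_stop_epoch(val_loss, patience=10))  # None
```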
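model_checkpoint_k and model_checkpoint_monitor amount to keeping the k best checkpoints by a monitored metric. This sketch keeps them in a bounded heap, assuming lower is better; top_k_checkpoints and the epoch-based file names are illustrative only and do not reproduce PyTorch Lightning's ModelCheckpoint.

```python
# Hypothetical sketch of top-k checkpoint retention.
# Assumptions: lower metric is better; names like "epoch=3.ckpt" are illustrative.
import heapq
from typing import List, Tuple


def top_k_checkpoints(metric_per_epoch: List[float], k: int = 1) -> List[Tuple[float, str]]:
    """Return the k (metric, checkpoint name) pairs with the lowest metric."""
    if k <= 0:
        return []
    # Store negated metrics so the worst kept checkpoint sits at heap[0].
    heap: List[Tuple[float, str]] = []
    for epoch, value in enumerate(metric_per_epoch):
        name = f"epoch={epoch}.ckpt"
        if len(heap) < k:
            heapq.heappush(heap, (-value, name))
        elif -heap[0][0] > value:
            heapq.heapreplace(heap, (-value, name))  # evict the current worst
    return sorted((-neg, name) for neg, name in heap)


val_loss = [0.9, 0.6, 0.7, 0.5, 0.8]
print(top_k_checkpoints(val_loss, k=2))  # [(0.5, 'epoch=3.ckpt'), (0.6, 'epoch=1.ckpt')]
```

A heap keeps each update at O(log k) instead of re-sorting every saved checkpoint.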