eugene.models.SequenceModule
- class eugene.models.SequenceModule(arch, task='regression', loss_fxn='mse', optimizer='adam', optimizer_lr=0.001, optimizer_kwargs={}, scheduler=None, scheduler_monitor='val_loss_epoch', scheduler_kwargs={}, metric=None, metric_kwargs={}, seed=None, save_hyperparams=True, arch_name=None, model_name=None)
Base LightningModule class for EUGENe that handles models that predict single tensor outputs.
SequenceModules expect to train an architecture that ingests a single tensor (usually one-hot encoded DNA sequences) as input and outputs a single tensor. Examples of models that can be trained with SequenceModule include:
DeepBind
DeepSEA
DanQ
Basset
DeepSTARR
And many more!
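All of these architectures consume a single input tensor, usually one-hot encoded DNA. The sketch below shows one common way to build that tensor. It assumes an `ACGT` channel order, a channels-first `(4, length)` layout, and all-zero columns for ambiguous bases such as `N`; these conventions, and the helper name `one_hot_encode`, are illustrative assumptions rather than EUGENe's own preprocessing.

```python
import numpy as np

# Assumed channel order: row 0 = A, 1 = C, 2 = G, 3 = T
ALPHABET = "ACGT"

def one_hot_encode(seq: str) -> np.ndarray:
    """Hypothetical helper: return a (4, len(seq)) float32 one-hot array.

    Bases outside ALPHABET (e.g. N) are left as all-zero columns.
    """
    seq = seq.upper()
    encoded = np.zeros((len(ALPHABET), len(seq)), dtype=np.float32)
    for pos, base in enumerate(seq):
        idx = ALPHABET.find(base)  # -1 for unknown bases
        if idx >= 0:
            encoded[idx, pos] = 1.0
    return encoded

# Stack equal-length sequences into a single (batch, 4, length) input
batch = np.stack([one_hot_encode(s) for s in ["ACGTN", "TTGCA"]])
print(batch.shape)  # (2, 4, 5)
```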
The SequenceModule class configures the loss function, optimizer, scheduler, and metric for the model, using defaults unless they are specified explicitly. Custom loss functions are supported, but optimizers, schedulers, and metrics must currently be chosen from a predefined set; the supported options are listed under the corresponding parameters below. We are working on adding support for custom optimizers, schedulers, and metrics.
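Because `loss_fxn` also accepts a Callable, a custom loss can be passed in directly. The sketch below assumes the callable follows the usual PyTorch convention `loss(predictions, targets) -> scalar tensor`; the exact call signature SequenceModule uses is not stated here, so verify it before relying on this. The factory `make_weighted_mse` and its per-task `weights` are hypothetical.

```python
import torch

def make_weighted_mse(weights: torch.Tensor):
    """Hypothetical factory: build an MSE loss that weights each output task."""
    def weighted_mse(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # preds, targets: (batch, n_tasks); weights broadcasts over the batch dim
        return ((preds - targets) ** 2 * weights).mean()
    return weighted_mse

loss_fn = make_weighted_mse(torch.tensor([1.0, 0.5]))
preds = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targets = torch.tensor([[0.0, 2.0], [3.0, 2.0]])
print(loss_fn(preds, targets))  # tensor(0.7500)
```

Under that assumption, it would be passed as `SequenceModule(arch, task="regression", loss_fxn=loss_fn)`.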
- Parameters:
arch (torch.nn.Module) – The architecture to train. e.g. DeepBind, DeepSEA, DanQ, Basset, DeepSTARR
task (Literal["regression", "binary_classification", "multiclass_classification", "multilabel_classification"]) – task of the model. SequenceModule currently supports “regression”, “binary_classification”, “multiclass_classification” and “multilabel_classification”
loss_fxn (Union[str, Callable]) – loss function to use. If not specified, it will be inferred from the task; for example, if the task is “regression”, the loss function will be set to “mse”. Custom loss functions can be passed in as a Callable.
optimizer (Literal["adam", "sgd"]) – optimizer to use. We currently support “adam” and “sgd”
optimizer_lr (float) – starting learning rate for the optimizer
optimizer_kwargs (dict) – additional arguments to pass to the optimizer
scheduler (Optional[str]) – scheduler to use. We currently support “reduce_lr_on_plateau”
scheduler_monitor (str) – metric to monitor for the scheduler
scheduler_kwargs (dict) – additional arguments to pass to the scheduler
metric (Optional[str]) – metric (other than loss) to track during training. If not specified, it will be inferred from the task; for example, if the task is “regression”, the metric will be set to “r2score”. We currently support “r2score”, “pearson”, “spearman”, “explainedvariance”, “auroc”, “accuracy”, “f1score”, “precision”, and “recall”.
metric_kwargs (dict) – additional arguments to pass to the metric. Note that in many cases, a given metric and task combination requires specific keyword arguments. For example, “multilabel_classification” should use “auroc” as the metric, with the “task” keyword argument set to “multilabel”. See the documentation for more details.
seed (Optional[int]) – seed to use for reproducibility. If not specified, no seed will be set.
save_hyperparams (bool) – whether to save the hyperparameters of the model. If True, the hyperparameters will be saved in the model checkpoint directory
arch_name (Optional[str]) – name of the architecture. If not specified, it will be inferred from the architecture class name
model_name (Optional[str]) – name of the specific instantiation of the model. If not specified, it will be set to “model”. This is useful for keeping track of multiple models that might have the same architecture
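The `optimizer`, `optimizer_lr`, `optimizer_kwargs`, `scheduler`, `scheduler_monitor`, and `scheduler_kwargs` parameters map naturally onto standard PyTorch objects and the dictionary format Lightning's `configure_optimizers` may return. The sketch below is a guess at the shape of that mapping, not EUGENe's actual code: the lookup tables and the function name `build_optimizers` are hypothetical, and only `torch.optim.Adam`, `torch.optim.SGD`, and `torch.optim.lr_scheduler.ReduceLROnPlateau` are real library classes.

```python
import torch

# Hypothetical lookup tables mirroring the documented string options
OPTIMIZERS = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
SCHEDULERS = {"reduce_lr_on_plateau": torch.optim.lr_scheduler.ReduceLROnPlateau}

def build_optimizers(params, optimizer="adam", optimizer_lr=1e-3, optimizer_kwargs=None,
                     scheduler=None, scheduler_monitor="val_loss_epoch", scheduler_kwargs=None):
    """Hypothetical helper: turn string options into an optimizer (and scheduler)."""
    opt = OPTIMIZERS[optimizer](params, lr=optimizer_lr, **(optimizer_kwargs or {}))
    if scheduler is None:
        return opt  # Lightning also accepts a bare optimizer
    sched = SCHEDULERS[scheduler](opt, **(scheduler_kwargs or {}))
    # "monitor" names the logged value passed to ReduceLROnPlateau.step()
    return {"optimizer": opt,
            "lr_scheduler": {"scheduler": sched, "monitor": scheduler_monitor}}

model = torch.nn.Linear(4, 1)
config = build_optimizers(model.parameters(), optimizer="sgd",
                          optimizer_kwargs={"momentum": 0.9},
                          scheduler="reduce_lr_on_plateau",
                          scheduler_kwargs={"patience": 2})
print(type(config["optimizer"]).__name__)  # SGD
```

Keeping the scheduler inside the `lr_scheduler` sub-dictionary is what lets a plateau scheduler know which logged quantity (here the default `val_loss_epoch`) to watch.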
- __init__(arch, task='regression', loss_fxn='mse', optimizer='adam', optimizer_lr=0.001, optimizer_kwargs={}, scheduler=None, scheduler_monitor='val_loss_epoch', scheduler_kwargs={}, metric=None, metric_kwargs={}, seed=None, save_hyperparams=True, arch_name=None, model_name=None)
Methods
__init__(arch[, task, loss_fxn, optimizer, ...])
add_module(name, module) – Adds a child module to the current module.
all_gather(data[, group, sync_grads]) – Gather tensors or collections of tensors from multiple processes.
apply(fn) – Applies fn recursively to every submodule (as returned by .children()) as well as self.
backward(loss, *args, **kwargs) – Called to perform backward on the loss returned in training_step().
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Returns an iterator over module buffers.
children() – Returns an iterator over immediate children modules.
clip_gradients(optimizer[, ...]) – Handles gradient clipping internally.
configure_callbacks() – Configure model-specific callbacks.
configure_gradient_clipping(optimizer[, ...]) – Perform gradient clipping for the optimizer parameters.
configure_metrics(metric, metric_kwargs) – Configure metrics.
configure_optimizers() – Configure optimizers.
configure_sharded_model() – Hook to create modules in a distributed aware context.
cpu() – See torch.nn.Module.cpu().
cuda([device]) – Moves all model parameters and buffers to the GPU.
double() – See torch.nn.Module.double().
eval() – Sets the module in evaluation mode.
extra_repr() – Set the extra representation of the module.
float() – See torch.nn.Module.float().
forward(x) – Forward pass of the arch.
freeze() – Freeze all params for inference.
get_buffer(target) – Returns the buffer given by target if it exists, otherwise throws an error.
get_extra_state() – Returns any extra state to include in the module's state_dict.
get_parameter(target) – Returns the parameter given by target if it exists, otherwise throws an error.
get_submodule(target) – Returns the submodule given by target if it exists, otherwise throws an error.
half() – See torch.nn.Module.half().
ipu([device]) – Moves all model parameters and buffers to the IPU.
load_from_checkpoint(checkpoint_path[, ...]) – Primary way of loading a model from a checkpoint.
load_state_dict(state_dict[, strict]) – Copies parameters and buffers from state_dict into this module and its descendants.
log(name, value[, prog_bar, logger, ...]) – Log a key, value pair.
log_dict(dictionary[, prog_bar, logger, ...]) – Log a dictionary of values at once.
lr_scheduler_step(scheduler, metric) – Override this method to adjust the default way the Trainer calls each scheduler.
lr_schedulers() – Returns the learning rate scheduler(s) that are being used during training.
manual_backward(loss, *args, **kwargs) – Call this directly from your training_step() when doing optimizations manually.
modules() – Returns an iterator over all modules in the network.
named_buffers([prefix, recurse]) – Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse]) – Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
on_after_backward() – Called after loss.backward() and before optimizers are stepped.
on_after_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch after it is transferred to the device.
on_before_backward(loss) – Called before loss.backward().
on_before_batch_transfer(batch, dataloader_idx) – Override to alter or apply batch augmentations to your batch before it is transferred to the device.
on_before_optimizer_step(optimizer) – Called before optimizer.step().
on_before_zero_grad(optimizer) – Called after training_step() and before optimizer.zero_grad().
on_fit_end() – Called at the very end of fit.
on_fit_start() – Called at the very beginning of fit.
on_load_checkpoint(checkpoint) – Called by Lightning to restore your model.
on_predict_batch_end(outputs, batch, batch_idx) – Called in the predict loop after the batch.
on_predict_batch_start(batch, batch_idx[, ...]) – Called in the predict loop before anything happens for that batch.
on_predict_end() – Called at the end of predicting.
on_predict_epoch_end() – Called at the end of predicting.
on_predict_epoch_start() – Called at the beginning of predicting.
on_predict_model_eval() – Sets the model to eval during the predict loop.
on_predict_start() – Called at the beginning of predicting.
on_save_checkpoint(checkpoint) – Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
on_test_batch_end(outputs, batch, batch_idx) – Called in the test loop after the batch.
on_test_batch_start(batch, batch_idx[, ...]) – Called in the test loop before anything happens for that batch.
on_test_end() – Called at the end of testing.
on_test_epoch_end() – Called in the test loop at the very end of the epoch.
on_test_epoch_start() – Called in the test loop at the very beginning of the epoch.
on_test_model_eval() – Sets the model to eval during the test loop.
on_test_model_train() – Sets the model to train during the test loop.
on_test_start() – Called at the beginning of testing.
on_train_batch_end(outputs, batch, batch_idx) – Called in the training loop after the batch.
on_train_batch_start(batch, batch_idx) – Called in the training loop before anything happens for that batch.
on_train_end() – Called at the end of training before logger experiment is closed.
on_train_epoch_end() – Called in the training loop at the very end of the epoch.
on_train_epoch_start() – Called in the training loop at the very beginning of the epoch.
on_train_start() – Called at the beginning of training after sanity check.
on_validation_batch_end(outputs, batch, ...) – Called in the validation loop after the batch.
on_validation_batch_start(batch, batch_idx) – Called in the validation loop before anything happens for that batch.
on_validation_end() – Called at the end of validation.
on_validation_epoch_end() – Called in the validation loop at the very end of the epoch.
on_validation_epoch_start() – Called in the validation loop at the very beginning of the epoch.
on_validation_model_eval() – Sets the model to eval during the val loop.
on_validation_model_train() – Sets the model to train during the val loop.
on_validation_start() – Called at the beginning of validation.
optimizer_step(epoch, batch_idx, optimizer) – Override this method to adjust the default way the Trainer calls the optimizer.
optimizer_zero_grad(epoch, batch_idx, optimizer) – Override this method to change the default behaviour of optimizer.zero_grad().
optimizers([use_pl_optimizer]) – Returns the optimizer(s) that are being used during training.
parameters([recurse]) – Returns an iterator over module parameters.
predict(x[, batch_size, verbose]) – Predict the output of the model in batches.
predict_dataloader() – An iterable or collection of iterables specifying prediction samples.
predict_step(batch, batch_idx[, dataloader_idx]) – Predict step.
prepare_data() – Use this to download and prepare data.
print(*args, **kwargs) – Prints only from process 0.
register_backward_hook(hook) – Registers a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Adds a buffer to the module.
register_forward_hook(hook) – Registers a forward hook on the module.
register_forward_pre_hook(hook) – Registers a forward pre-hook on the module.
register_full_backward_hook(hook) – Registers a backward hook on the module.
register_load_state_dict_post_hook(hook) – Registers a post hook to be run after module's load_state_dict is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Adds a parameter to the module.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save_hyperparameters(*args[, ignore, frame, ...]) – Save arguments to hparams attribute.
set_extra_state(state) – This function is called from load_state_dict() to handle any extra state found within the state_dict.
setup(stage) – Called at the beginning of fit (train + validate), validate, test, or predict.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Returns a dictionary containing references to the whole state of the module.
summary() – Print a summary of the model.
teardown(stage) – Called at the end of fit (train + validate), validate, test, or predict.
test_dataloader() – An iterable or collection of iterables specifying test samples.
test_step(batch, batch_idx) – Test step.
to(*args, **kwargs) – See torch.nn.Module.to().
to_empty(*, device) – Moves the parameters and buffers to the specified device without copying storage.
to_onnx(file_path[, input_sample]) – Saves the model in ONNX format.
to_torchscript([file_path, method, ...]) – By default compiles the whole model to a ScriptModule.
toggle_optimizer(optimizer) – Makes sure only the gradients of the current optimizer's parameters are calculated in the training step to prevent dangling gradients in multiple-optimizer setup.
train([mode]) – Sets the module in training mode.
train_dataloader() – An iterable or collection of iterables specifying training samples.
training_step(batch, batch_idx) – Training step.
transfer_batch_to_device(batch, device, ...) – Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
type(dst_type) – See torch.nn.Module.type().
unfreeze() – Unfreeze all parameters for training.
untoggle_optimizer(optimizer) – Resets the state of required gradients that were toggled with toggle_optimizer().
val_dataloader() – An iterable or collection of iterables specifying validation samples.
validation_step(batch, batch_idx) – Validation step.
xpu([device]) – Moves all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Sets gradients of all model parameters to zero.
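`predict(x[, batch_size, verbose])` is described only as predicting the model output in batches. A minimal sketch of that general technique in plain PyTorch follows; it is not EUGENe's implementation, and both the helper `predict_in_batches` and the toy convolutional architecture are hypothetical. It switches the module to evaluation mode, disables gradient tracking, runs fixed-size slices of the input, concatenates the outputs, and then restores the previous mode.

```python
import torch
from torch import nn

def predict_in_batches(model: nn.Module, x: torch.Tensor, batch_size: int = 128) -> torch.Tensor:
    """Hypothetical helper: run model on x in slices of batch_size and concatenate."""
    was_training = model.training
    model.eval()  # disable dropout / use running batch-norm statistics
    outputs = []
    with torch.no_grad():  # no autograd graph is needed for inference
        for start in range(0, x.shape[0], batch_size):
            outputs.append(model(x[start:start + batch_size]))
    model.train(was_training)  # restore whatever mode the model was in
    return torch.cat(outputs, dim=0)

# Toy single-input, single-output architecture over (batch, 4, length) one-hot input
arch = nn.Sequential(nn.Conv1d(4, 8, kernel_size=5), nn.ReLU(),
                     nn.AdaptiveMaxPool1d(1), nn.Flatten(), nn.Linear(8, 1))
x = torch.rand(10, 4, 50)
preds = predict_in_batches(arch, x, batch_size=4)
print(preds.shape)  # torch.Size([10, 1])
```

Batching bounds peak memory by `batch_size` rather than by the full dataset, which matters when scoring many long sequences.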
Attributes
CHECKPOINT_HYPER_PARAMS_KEY
CHECKPOINT_HYPER_PARAMS_NAME
CHECKPOINT_HYPER_PARAMS_TYPE
T_destination
arch – Model
automatic_optimization – If set to False you are responsible for calling .backward(), .step(), .zero_grad().
current_epoch – The current epoch in the Trainer, or 0 if not attached.
device
dtype
dump_patches
example_input_array – The example input array is a specification of what the module can consume in the forward() method.
fabric
global_rank – The index of the current process across all nodes and devices.
global_step – Total training batches seen across all epochs.
hparams – The collection of hyperparameters saved with save_hyperparameters().
hparams_initial – The collection of hyperparameters saved with save_hyperparameters().
local_rank – The index of the current process within a single node.
logger – Reference to the logger object in the Trainer.
loggers – Reference to the list of loggers in the Trainer.
loss_fxn – Loss function
on_gpu – Returns True if this model is currently located on a GPU.
optimizer – Optimizer
optimizer_lr – Optimizer starting learning rate
scheduler – Scheduler
seed – Seed
task – Task
train_metric – Train metric
trainer
training