eugene.evaluate.train_val_predictions_sequence_module

eugene.evaluate.train_val_predictions_sequence_module(model, sdata=None, seq_var='ohe_seq', target_vars=None, train_var='train_val', gpus=None, batch_size=None, num_workers=None, transforms=None, prefetch_factor=None, store_only=False, in_memory=False, out_dir=None, name=None, version='', prefix='', suffix='', copy=False)

Generate predictions from a SequenceModule model on the training and validation sets of a SeqData object.

This is a wrapper around the predictions function. It takes a SeqData object, builds a dataloader from it, runs the model, and adds the resulting predictions to the SeqData object.
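
The wrapper pattern described above can be sketched without EUGENe. This is a minimal, hypothetical illustration, not the library's implementation. A plain dict of lists stands in for the SeqData object, a Python callable stands in for the SequenceModule, fixed-size slicing stands in for the dataloader, and the helper names (batches, toy_train_val_predictions) and the "predictions" output key are assumptions made for the example.

```python
from typing import Callable, Dict, Iterator, List, Sequence


def batches(items: List, batch_size: int) -> Iterator[List]:
    """Stand-in for a DataLoader: yield consecutive fixed-size chunks."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def toy_train_val_predictions(
    model: Callable[[List[Sequence[float]]], List[float]],
    sdata: Dict[str, list],
    seq_var: str = "ohe_seq",
    train_var: str = "train_val",
    batch_size: int = 2,
) -> None:
    """Hypothetical sketch: predict on the train and val splits, store results in sdata."""
    seqs = sdata[seq_var]
    is_train = sdata[train_var]
    preds: List[float] = [0.0] * len(seqs)
    for split_flag in (True, False):
        # Indices of the sequences that belong to this split.
        idx = [i for i, flag in enumerate(is_train) if bool(flag) == split_flag]
        split_preds: List[float] = []
        # Run the model batch by batch over this split only.
        for batch in batches([seqs[i] for i in idx], batch_size):
            split_preds.extend(model(batch))
        # Write each prediction back to the row it came from.
        for i, p in zip(idx, split_preds):
            preds[i] = p
    # Hypothetical output key; the real function's naming may differ.
    sdata["predictions"] = preds


# Usage: a toy "model" that sums each one-hot row.
sdata = {"ohe_seq": [[1, 0], [0, 1], [1, 1]], "train_val": [True, False, True]}
toy_train_val_predictions(lambda batch: [float(sum(s)) for s in batch], sdata)
print(sdata["predictions"])  # [1.0, 1.0, 2.0]
```

Writing each split's predictions back to its original row indices keeps the stored predictions aligned with the sequences, whatever order the train and validation examples appear in.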

Parameters:
  • model (LightningModule) – Model to predict with.

  • sdata (xr.Dataset, optional) – SeqData object to predict with. If None, uses the sdata in settings.

  • seq_var (str, optional) – Key in sdata holding the sequences to predict on. Defaults to “ohe_seq”.

  • target_vars (str or list of str, optional) – Key(s) in sdata holding the target values. Defaults to None.

  • train_var (str, optional) – Key in sdata holding the train/validation split indicator. Defaults to “train_val”.

  • gpus (int, optional) – Number of GPUs to use. If None, uses settings.gpus.

  • batch_size (int, optional) – Batch size to use. If None, uses settings.batch_size.

  • num_workers (int, optional) – Number of workers to use. If None, uses settings.dl_num_workers.

  • transforms (dict, optional) –
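
Several parameters above fall back to a global settings object when left as None (for example, gpus uses settings.gpus and num_workers uses settings.dl_num_workers). A minimal sketch of that resolution pattern follows. It uses a hypothetical ToySettings dataclass with made-up default values rather than EUGENe's actual settings object.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToySettings:
    # Hypothetical global defaults; the values are illustrative, not EUGENe's.
    gpus: int = 0
    batch_size: int = 128
    dl_num_workers: int = 0


settings = ToySettings()


def resolve(value: Optional[int], default: int) -> int:
    """Use the explicit argument if given, otherwise the settings default."""
    return default if value is None else value


def resolve_loader_args(
    gpus: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Resolve each None argument against the global settings."""
    return (
        resolve(gpus, settings.gpus),
        resolve(batch_size, settings.batch_size),
        resolve(num_workers, settings.dl_num_workers),
    )


print(resolve_loader_args(batch_size=32))  # (0, 32, 0)
```

Testing with `is None` rather than `value or default` matters here: an explicit `gpus=0` or `num_workers=0` is a meaningful choice and should not be silently replaced by the settings value.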

Returns:

sdata – If copy=True, a SeqData object with the predictions added. If copy=False, the predictions are added to sdata in place and None is returned.

Return type:

xr.Dataset or None
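
The copy flag follows a copy-or-modify-in-place convention. A minimal sketch of that contract follows. It uses a plain dict as a stand-in for SeqData and a hypothetical add_predictions helper. It also assumes that copy=True works on a deep copy and leaves the input object unchanged, as in the common scanpy-style convention; the source above only states what is returned.

```python
import copy as copy_module
from typing import Dict, List, Optional


def add_predictions(
    sdata: Dict[str, list], preds: List[float], copy: bool = False
) -> Optional[Dict[str, list]]:
    """Hypothetical helper illustrating copy=True/False return semantics."""
    # Assumption: copy=True operates on a deep copy, leaving the input untouched.
    target = copy_module.deepcopy(sdata) if copy else sdata
    target["predictions"] = list(preds)
    # Mirror the documented contract: return the object only when copy=True.
    return target if copy else None


original = {"ohe_seq": [[1, 0], [0, 1]]}
result = add_predictions(original, [0.5, 0.7], copy=True)
print("predictions" in original, result["predictions"])  # False [0.5, 0.7]
print(add_predictions(original, [0.5, 0.7]))  # None
print(original["predictions"])  # [0.5, 0.7]
```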