eugene.interpret.positional_gia_sdata

eugene.interpret.positional_gia_sdata(model, sdata, feature, feature_name='feature', seq_var='ohe_seq', id_var='id', store_var=None, device='cpu', encoding='onehot')

Implant a feature into all sequences in an xarray dataset and return the model predictions.
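The function name suggests a positional global importance analysis (GIA). In that approach, the feature is slid along each sequence, overwriting the bases at each start position, and the model's output is recorded for every variant. The sketch below illustrates this idea using only NumPy. It is not EUGENe's implementation, and it rests on several assumptions. Sequences are one-hot arrays of shape (length, alphabet). `predict` is a callable that maps a batch of such arrays to one score per array. `implant_at`, `positional_gia` and `g_content` are hypothetical helpers.

```python
import numpy as np


def implant_at(seq: np.ndarray, feature: np.ndarray, start: int) -> np.ndarray:
    """Return a copy of a one-hot seq (L, A) with feature (k, A) written at start."""
    implanted = seq.copy()
    implanted[start:start + feature.shape[0]] = feature
    return implanted


def positional_gia(predict, seqs: np.ndarray, feature: np.ndarray) -> np.ndarray:
    """Score every implant position for every sequence (hypothetical helper).

    seqs has shape (N, L, A); feature has shape (k, A).
    Returns an array of shape (N, L - k + 1).
    """
    n_seqs, seq_len, _ = seqs.shape
    n_positions = seq_len - feature.shape[0] + 1
    results = np.empty((n_seqs, n_positions))
    for i, seq in enumerate(seqs):
        # One batch per sequence: the feature implanted at each valid start.
        batch = np.stack([implant_at(seq, feature, p) for p in range(n_positions)])
        results[i] = predict(batch)
    return results


# Toy stand-in for a model: score each sequence by its G count
# (column 2 of an assumed ACGT channel order).
def g_content(batch: np.ndarray) -> np.ndarray:
    return batch[:, :, 2].sum(axis=1)


alphabet = "ACGT"
eye = np.eye(4)
seqs = np.stack([eye[[alphabet.index(b) for b in s]] for s in ["AAAAAA", "GAAAAG"]])
feature = eye[[alphabet.index(b) for b in "GG"]]
scores = positional_gia(g_content, seqs, feature)
print(scores.shape)  # (2, 5)
```

Each row of `scores` corresponds to one input sequence, and each column corresponds to one implant start position. This layout is the sketch's own choice and may not match the array this function actually returns.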

Parameters:
  • model (torch.nn.Module) – The model to use for predictions.

  • sdata (xr.Dataset) – The dataset containing the sequence data.

  • feature (np.ndarray) – The feature to implant.

  • feature_name (str, optional) – The name of the feature, by default “feature”.

  • seq_var (str, optional) – The key for the sequence data in the dataset, by default “ohe_seq”.

  • id_var (str, optional) – The key for the sequence IDs in the dataset, by default “id”.

  • store_var (str, optional) – The key under which to store the predictions in the dataset, by default None (predictions are returned but not stored).

  • device (str, optional) – The device to use for predictions, by default “cpu”.

  • encoding (str, optional) – The encoding of the sequence data, either “onehot” or “str”, by default “onehot”.
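With `encoding="str"`, the sequences are plain strings rather than one-hot arrays. The following minimal sketch shows the same implant step carried out on strings and then converted to one-hot arrays for a model. It assumes a string feature, an ACGT alphabet in that channel order, and the hypothetical helpers `implant_str` and `one_hot`. It does not describe how EUGENe handles the "str" case internally.

```python
import numpy as np

ALPHABET = "ACGT"  # assumed channel order


def implant_str(seq: str, feature: str, start: int) -> str:
    """Overwrite len(feature) characters of seq beginning at start."""
    return seq[:start] + feature + seq[start + len(feature):]


def one_hot(seq: str) -> np.ndarray:
    """Encode an ACGT string as an (L, 4) float array."""
    idx = [ALPHABET.index(base) for base in seq.upper()]
    return np.eye(len(ALPHABET))[idx]


seq, feature = "ACGTACGT", "TTT"
# One variant per valid start position; each keeps the original length.
variants = [implant_str(seq, feature, p) for p in range(len(seq) - len(feature) + 1)]
print(variants[0], variants[-1])  # TTTTACGT ACGTATTT
batch = np.stack([one_hot(v) for v in variants])
print(batch.shape)  # (6, 8, 4)
```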

Returns:

The model predictions for the feature-implanted sequences.

Return type:

np.ndarray
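As an illustration, the returned array could be summarized along the following lines. This assumes that rows are sequences and columns are implant positions. It also assumes that baseline predictions on the unmodified sequences were computed separately, because the function as documented returns only the implanted predictions. All numbers below are made up for illustration.

```python
import numpy as np

# Hypothetical implanted predictions: 3 sequences x 4 implant positions.
preds = np.array([
    [0.1, 0.5, 0.9, 0.2],
    [0.2, 0.4, 0.8, 0.1],
    [0.0, 0.6, 1.0, 0.3],
])
# Hypothetical predictions on the original, unmodified sequences.
baseline = np.array([0.1, 0.2, 0.0])

# Effect of the feature at each position, relative to each sequence's baseline.
effect = preds - baseline[:, None]
positional_effect = effect.mean(axis=0)  # average over sequences
best_position = int(np.argmax(positional_effect))
print(positional_effect, best_position)
```

Subtracting a per-sequence baseline before averaging follows the usual GIA framing, in which the quantity of interest is the change in prediction caused by the implanted feature.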