eugene.interpret.positional_gia_sdata
- eugene.interpret.positional_gia_sdata(model, sdata, feature, feature_name='feature', seq_var='ohe_seq', id_var='id', store_var=None, device='cpu', encoding='onehot')
Implant a feature into all sequences in an xarray dataset and return the model predictions.
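The sketch below illustrates the implanting step for a single sequence. It assumes that "positional" means the feature is written in at every valid start position, producing one modified copy of the sequence per position. The helper name `implant_at_every_position` and the `(length, alphabet)` one-hot layout are assumptions for illustration; this is not EUGENe's implementation.

```python
import numpy as np


def implant_at_every_position(seq: np.ndarray, feature: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return one copy of `seq` per start position with `feature` written in.

    Assumes one-hot arrays laid out as (length, alphabet); this layout is an
    illustrative assumption, not necessarily what EUGENe uses internally.
    """
    seq_len, feat_len = seq.shape[0], feature.shape[0]
    if feat_len > seq_len:
        raise ValueError("feature is longer than the sequence")
    n_positions = seq_len - feat_len + 1
    # np.repeat returns a new array, so the original sequence is never modified.
    implanted = np.repeat(seq[np.newaxis], n_positions, axis=0)
    for start in range(n_positions):
        # Overwrite the window [start, start + feat_len) with the feature.
        implanted[start, start:start + feat_len] = feature
    return implanted


alphabet = "ACGT"
seq = np.eye(4)[[alphabet.index(c) for c in "AAAAAA"]]
feature = np.eye(4)[[alphabet.index(c) for c in "GC"]]
out = implant_at_every_position(seq, feature)
print(out.shape)  # (5, 6, 4)
print("".join(alphabet[i] for i in out[2].argmax(axis=1)))  # AAGCAA
```

A length-6 sequence and a length-2 feature give 5 start positions; the copy at index 2 carries the feature at positions 2 and 3.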
- Parameters:
model (torch.nn.Module) – The model to use for predictions.
sdata (xr.Dataset) – The dataset containing the sequence data.
feature (np.ndarray) – The feature to implant.
feature_name (str, optional) – The name of the feature, by default “feature”.
seq_var (str, optional) – The key for the sequence data in the dataset, by default “ohe_seq”.
id_var (str, optional) – The key for the sequence IDs in the dataset, by default “id”.
store_var (str, optional) – The key to store the predictions in the dataset, by default None.
device (str, optional) – The device to use for predictions, by default “cpu”.
encoding (str, optional) – The encoding of the sequence data, either “onehot” or “str”, by default “onehot”.
- Returns:
The model predictions.
- Return type:
np.ndarray
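As a rough illustration of what a dataset-level positional readout could look like, the sketch below scores every (sequence, start position) pair and collects the results into a 2-D array. Everything here is an assumption for illustration: the helper `positional_gia_sketch`, the `toy_predict` stand-in for a model, the `(n_seqs, length, alphabet)` layout, and the `(n_seqs, n_positions)` output shape. The real function takes a `torch.nn.Module` and an xarray dataset, and its returned array may be shaped differently.

```python
from typing import Callable

import numpy as np


def positional_gia_sketch(
    predict: Callable[[np.ndarray], np.ndarray],
    seqs: np.ndarray,
    feature: np.ndarray,
) -> np.ndarray:
    """Hypothetical helper: score each sequence with `feature` implanted at each start position.

    seqs is (n_seqs, length, alphabet), feature is (feat_len, alphabet), and
    predict maps a (batch, length, alphabet) array to one score per row.
    Returns an (n_seqs, n_positions) array of scores.
    """
    n_seqs, seq_len, _ = seqs.shape
    feat_len = feature.shape[0]
    n_positions = seq_len - feat_len + 1
    # Row p of `cols` holds the sequence indices covered when implanting at start p.
    rows = np.arange(n_positions)[:, np.newaxis]
    cols = rows + np.arange(feat_len)[np.newaxis, :]
    scores = np.empty((n_seqs, n_positions))
    for i in range(n_seqs):
        # One copy of sequence i per start position; the feature is broadcast into every window at once.
        batch = np.repeat(seqs[i][np.newaxis], n_positions, axis=0)
        batch[rows, cols] = feature
        scores[i] = predict(batch)
    return scores


def toy_predict(x: np.ndarray) -> np.ndarray:
    # Toy "model": G content weighted by position, so later implants score higher.
    return (x[:, :, 2] * np.arange(x.shape[1])).sum(axis=1)


alphabet = "ACGT"
seqs = np.stack([np.eye(4)[[alphabet.index(c) for c in s]] for s in ["AAAAAA", "CCCCCC"]])
feature = np.eye(4)[[alphabet.index(c) for c in "GG"]]
scores = positional_gia_sketch(toy_predict, seqs, feature)
print(scores.shape)           # (2, 5)
print(scores.argmax(axis=1))  # [4 4]
```

With this toy model, implanting "GG" at start p adds p + (p + 1) to the score, so each row is [1, 3, 5, 7, 9] and the best position is the last one. Reading off the argmax per row is one simple way to locate where a feature has the strongest effect.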