eugene.interpret.attribute_sdata
- eugene.interpret.attribute_sdata(model, sdata, seq_var='ohe_seq', method='InputXGradient', reference_type=None, target=0, batch_size=None, device=None, num_workers=None, prefetch_factor=None, transforms=None, prefix='', suffix='', copy=False)
Compute attributions for a model and SeqData combination.
This function wraps the attribute function from the seqexplainer package to compute attributions of model for the sequences stored in sdata. The attributions are stored in sdata as a new variable. They are computed in batches, rather than all at once, to limit peak memory use.
- Parameters:
model (nn.Module) – Model to compute attributions for.
sdata (xr.Dataset) – SeqData to compute attributions for.
seq_var (str, optional) – Name of the sequence variable in sdata, by default “ohe_seq”.
method (str, optional) – Attribution method to use, by default “InputXGradient”.
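reference_type (Optional[str], optional) – Type of reference (baseline) sequences used by reference-based attribution methods, by default None.
target (int, optional) – Index of the model output (for example, the target class) to compute attributions for, by default 0.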
batch_size (Optional[int], optional) – Number of sequences per batch, by default None.
device (Optional[str], optional) – Device to compute attributions on (for example “cpu” or “cuda”), by default None.
num_workers (Optional[int], optional) – Number of worker processes used to load the data, by default None.
prefetch_factor (Optional[int], optional) – Number of batches each worker loads in advance, by default None.
transforms (Optional[Dict[str, Any]], optional) – Additional transforms to apply to the data, by default None.
prefix (str, optional) – Prefix to add to the attribution variable name, by default “”.
suffix (str, optional) – Suffix to add to the attribution variable name, by default “”.
copy (bool, optional) – Whether to copy the data before adding the attribution variable, by default False.
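The batched computation and the role of target can be illustrated with a minimal NumPy sketch. It assumes a toy linear model whose input gradient is known in closed form, so InputXGradient reduces to the one-hot input multiplied elementwise by the weights of the selected output. The helpers linear_model_grad and attribute_in_batches are hypothetical and are not part of EUGENe or seqexplainer, which obtain gradients through PyTorch.

```python
import numpy as np


def linear_model_grad(weights: np.ndarray, target: int) -> np.ndarray:
    """Input gradient of output `target` of a toy linear model f(x)[t] = sum(W[t] * x).

    For a linear model the gradient w.r.t. the input is W[t], independent of x
    (hypothetical stand-in for autograd).
    """
    return weights[target]


def attribute_in_batches(
    seqs: np.ndarray, weights: np.ndarray, target: int = 0, batch_size: int = 2
) -> np.ndarray:
    """InputXGradient attributions computed one batch of sequences at a time."""
    grad = linear_model_grad(weights, target)  # shape (alphabet, length)
    chunks = []
    for start in range(0, len(seqs), batch_size):
        batch = seqs[start:start + batch_size]  # shape (batch, alphabet, length)
        chunks.append(batch * grad)  # input x gradient, broadcast over the batch
    return np.concatenate(chunks, axis=0)


# Five random one-hot sequences of length 8 over a 4-letter alphabet.
rng = np.random.default_rng(0)
idx = rng.integers(0, 4, size=(5, 8))
seqs = np.eye(4)[idx].transpose(0, 2, 1)  # shape (5, 4, 8)

# Toy model with 3 outputs: weights of shape (outputs, alphabet, length).
weights = rng.normal(size=(3, 4, 8))

attrs = attribute_in_batches(seqs, weights, target=1, batch_size=2)
print(attrs.shape)  # (5, 4, 8)
```

Splitting the input changes only how many sequences are processed at once; the concatenated result equals seqs * weights[1] computed in a single pass.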
- Returns:
A copy of sdata with the attribution variable added if copy is True; otherwise None, and sdata is modified in place.
- Return type:
Optional[xr.Dataset]
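The interaction between copy and the return value can be sketched with a plain dictionary standing in for the xr.Dataset. This hypothetical illustration assumes the convention suggested by the Optional[xr.Dataset] return type: a modified copy is returned when copy=True, otherwise the input is updated in place and None is returned. The helper add_attributions and the variable-name pattern f"{prefix}{method}_attrs{suffix}" are likewise assumptions, not documented EUGENe behavior.

```python
import copy as copy_module
from typing import Any, Dict, Optional


def add_attributions(
    data: Dict[str, Any],
    attrs: Any,
    method: str = "InputXGradient",
    prefix: str = "",
    suffix: str = "",
    copy: bool = False,
) -> Optional[Dict[str, Any]]:
    """Store attrs in data under a name built from prefix, method and suffix.

    Hypothetical stand-in: a dict plays the role of the xr.Dataset.
    """
    out = copy_module.deepcopy(data) if copy else data  # copy only when requested
    out[f"{prefix}{method}_attrs{suffix}"] = attrs  # assumed naming pattern
    return out if copy else None  # in-place updates return None


# copy=False: the original dict gains the new variable and None is returned.
sdata = {"ohe_seq": [[1, 0, 0, 0]]}
result = add_attributions(sdata, attrs=[[0.5, 0, 0, 0]], copy=False)
print(result)         # None
print(sorted(sdata))  # ['InputXGradient_attrs', 'ohe_seq']

# copy=True: the original is left untouched and the modified copy is returned.
fresh = {"ohe_seq": [[1, 0, 0, 0]]}
returned = add_attributions(fresh, attrs=[[0.5, 0, 0, 0]], prefix="run1_", copy=True)
print(sorted(fresh))     # ['ohe_seq']
print(sorted(returned))  # ['ohe_seq', 'run1_InputXGradient_attrs']
```

Returning None for in-place updates makes it explicit that the caller's object was changed, while copy=True leaves the input untouched.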