eugene.interpret.attribute_sdata

eugene.interpret.attribute_sdata(model, sdata, seq_var='ohe_seq', method='InputXGradient', reference_type=None, target=0, batch_size=None, device=None, num_workers=None, prefetch_factor=None, transforms=None, prefix='', suffix='', copy=False)

Compute attributions for a model on the sequences stored in a SeqData object.

This function wraps the attribute function from the seqexplainer package. Attributions are computed in batches of batch_size sequences to limit memory use. The result is stored in sdata as a new variable whose name can be customized with prefix and suffix.
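
The batching idea can be illustrated with a minimal NumPy sketch. This is not EUGENe's or seqexplainer's implementation. It assumes a hypothetical toy linear model with weights of shape (n_classes, 4, length), for which the gradient of output target with respect to the input is simply weights[target], so Input x Gradient reduces to an elementwise product. The helper names input_x_gradient_linear and attribute_in_batches are illustrative only.

```python
import numpy as np


def input_x_gradient_linear(weights: np.ndarray, x: np.ndarray, target: int) -> np.ndarray:
    """Input x Gradient for a hypothetical linear model f(x)[c] = sum(W[c] * x).

    For such a model d f[target] / d x equals W[target], so the attribution
    is x * W[target], broadcast over the batch axis.
    """
    return x * weights[target]


def attribute_in_batches(weights: np.ndarray, seqs: np.ndarray, target: int = 0,
                         batch_size=None) -> np.ndarray:
    """Compute attributions one batch at a time and concatenate the results."""
    n = seqs.shape[0]
    batch_size = batch_size or n  # assumption: None means a single batch
    chunks = []
    for start in range(0, n, batch_size):
        batch = seqs[start:start + batch_size]  # only this slice is processed at once
        chunks.append(input_x_gradient_linear(weights, batch, target))
    return np.concatenate(chunks, axis=0)


# Usage: 10 random one-hot sequences of length 8, shaped (N, 4, L).
rng = np.random.default_rng(0)
N, A, L, C = 10, 4, 8, 2
idx = rng.integers(0, A, size=(N, L))
ohe = np.eye(A)[idx].transpose(0, 2, 1)  # (N, L, A) -> (N, A, L)
W = rng.normal(size=(C, A, L))
attrs = attribute_in_batches(W, ohe, target=1, batch_size=3)
```

The batch size only changes how much data is processed at once, not the result: batch_size=3 and a single full batch give identical attributions.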

Parameters:
  • model (nn.Module) – Model to compute attributions for.

  • sdata (xr.Dataset) – SeqData to compute attributions for.

  • seq_var (str, optional) – Name of the sequence variable in sdata, by default “ohe_seq”.

  • method (str, optional) – Attribution method to use, by default “InputXGradient”.

  • reference_type (Optional[str], optional) – Reference type to use, by default None.

  • target (int, optional) – Target class to compute attributions for, by default 0.

  • batch_size (Optional[int], optional) – Batch size to use, by default None.

  • device (Optional[str], optional) – Device to use, by default None.

  • num_workers (Optional[int], optional) – Number of worker processes used to load the data, by default None.

  • prefetch_factor (Optional[int], optional) – Number of batches each worker loads in advance, by default None.

  • transforms (Optional[Dict[str, Any]], optional) – Additional transforms to apply to the data, by default None.

  • prefix (str, optional) – Prefix to add to the attribution variable name, by default “”.

  • suffix (str, optional) – Suffix to add to the attribution variable name, by default “”.

  • copy (bool, optional) – Whether to copy the data before adding the attribution variable, by default False.
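
How prefix, suffix and copy interact can be sketched as follows. This is a hypothetical helper, not part of EUGENe: a plain dict stands in for the xr.Dataset, the base variable name is passed in explicitly because this reference does not state it, and the return convention (a new object when copy=True, None otherwise) is an assumption consistent with the Optional[xr.Dataset] return type.

```python
from typing import Dict, Optional

import numpy as np


def store_attrs(
    data: Dict[str, np.ndarray],
    attrs: np.ndarray,
    base_name: str,
    prefix: str = "",
    suffix: str = "",
    copy: bool = False,
) -> Optional[Dict[str, np.ndarray]]:
    """Hypothetical helper mirroring the prefix/suffix/copy parameters."""
    # Work on a shallow copy when copy=True, otherwise modify the caller's object.
    out = dict(data) if copy else data
    out[f"{prefix}{base_name}{suffix}"] = attrs
    # Assumed convention: return the new object only when a copy was made.
    return out if copy else None


# Usage
sdata = {"ohe_seq": np.zeros((2, 4, 5))}
attrs = np.ones((2, 4, 5))
out = store_attrs(sdata, attrs, "my_attrs", prefix="run1_", copy=True)  # sdata untouched
res = store_attrs(sdata, attrs, "my_attrs", suffix="_v2")  # sdata modified in place
```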

Returns:

A copy of sdata with the attribution variable added if copy is True; otherwise None, and sdata is modified in place.

Return type:

Optional[xr.Dataset]