eugene

API

preprocess

from eugene import preprocess

This module is designed to let users interact with and modify SeqData objects to prepare them for model training and other steps of the workflow. There are three main classes of preprocessing functions.
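
A minimal sketch of a typical preprocessing pass, assuming sdata is an already loaded SeqData object; the padded length, alphabet, and dimension name below are illustrative assumptions rather than library defaults, and the calls are assumed to add their results to sdata in place.

from eugene import preprocess

# sdata is assumed to be an existing SeqData object (an xarray Dataset).
preprocess.make_unique_ids_sdata(sdata)                     # unique id per sequence
preprocess.pad_seqs_sdata(sdata, length=200)                # pad sequences to a fixed length (value is illustrative)
preprocess.ohe_seqs_sdata(sdata, alphabet="DNA")            # one-hot encode ("DNA" is an assumed alphabet value)
preprocess.train_test_random_split(sdata, dim="_sequence")  # random train/test labels ("_sequence" is an assumed dimension name)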

Sequence preprocessing

preprocess.make_unique_ids_sdata(sdata[, ...])

Make a set of unique ids for each sequence in a SeqData object and store as new xarray variable.

preprocess.pad_seqs_sdata(sdata, length[, ...])

Pad sequences in a SeqData object.

preprocess.ohe_seqs_sdata(sdata[, alphabet, ...])

One-hot encode sequences in a SeqData object.

Train-test splitting

preprocess.train_test_chrom_split(sdata, ...)

Add a variable labeling sequences as part of the train or test split based on chromosome.

preprocess.train_test_homology_split(sdata, ...)

Add a variable labeling sequences as part of the train or test split, splitting by homology.

preprocess.train_test_random_split(sdata, dim)

Add a variable labeling sequences as part of the train or test split, splitting randomly.

Target preprocessing

preprocess.clamp_targets_sdata(sdata, ...[, ...])

Clamp targets to a given percentile in a SeqData object.

preprocess.scale_targets_sdata(sdata, ...[, ...])

Scale targets in a SeqData object.

dataload

from eugene import dataload

This module is designed to help users prepare their SeqData objects for model training and other steps of the workflow (e.g. augmentation).
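
A minimal sketch of these utilities, assuming sdata_a and sdata_b are existing SeqData objects; the keys, the join column, and the metadata column names are illustrative assumptions, and add_obs is assumed to attach the metadata to sdata in place.

import pandas as pd
from eugene import dataload

# Concatenate two SeqData objects, keeping track of their origin with keys.
sdata = dataload.concat_sdatas([sdata_a, sdata_b], keys=["rep1", "rep2"])

# Attach per-sequence metadata from a pandas DataFrame (join column name is assumed).
obs = pd.DataFrame({"id": ["seq1", "seq2"], "cell_type": ["K562", "HepG2"]})
dataload.add_obs(sdata, obs, on="id")

# Augmentation: reverse-complement each sequence in a batch with probability 0.5.
augment = dataload.RandomRC(rc_prob=0.5)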

SeqData utilities

dataload.concat_sdatas(sdatas[, keys])

Concatenate multiple SeqDatas into one.

dataload.add_obs(sdata, obs[, on, left_on, ...])

Add observational metadata to a SeqData.

Augmentation

dataload.RandomRC([rc_prob])

Randomly applies a reverse-complement transformation to each sequence in a training batch

models

from eugene import models

This module is designed to allow users to easily build and initialize several neural network architectures tailored to biological sequences.
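
A minimal sketch of building, initializing, and wrapping an architecture for training; the input_len/output_dim values, the initializer string, and the task string are illustrative assumptions.

from eugene import models

# Instantiate a built-in architecture from the zoo.
arch = models.zoo.DeepBind(input_len=200, output_dim=1)

# Initialize its weights (the initializer name is an assumed example value).
models.init_weights(arch, initializer="kaiming_normal")

# Wrap the architecture in a LightningModule ("regression" is an assumed task value).
model = models.SequenceModule(arch, task="regression")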

Blocks

Blocks are composed to create architectures in EUGENe. You can find all the arguments that can be passed into the dense_kwargs and recurrent_kwargs arguments of all built-in models in the DenseBlock and RecurrentBlock classes, respectively. See the towers section for more information on the conv_kwargs argument.
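
For example, the blocks can be instantiated directly with their documented positional arguments (a minimal sketch; hidden-layer options and other keyword arguments are omitted):

from eugene import models

# Dense (fully connected) block mapping a 128-dimensional input to a single output.
dense = models.DenseBlock(input_dim=128, output_dim=1)

# Recurrent block over 4-channel (one-hot) input with a 32-unit hidden state.
recurrent = models.RecurrentBlock(input_dim=4, hidden_dim=32)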

models.DenseBlock(input_dim, output_dim[, ...])

A block for dense layers

models.Conv1DBlock(input_len, ...[, ...])

Flexible block for convolutional models

models.RecurrentBlock(input_dim, hidden_dim)

A block for recurrent layers

Towers

The Conv1DTower class is currently used for all built-in CNNs. This will be deprecated in the future in favor of the more general Tower class. For now, you can find all the arguments that would be passed into the cnn_kwargs argument of all built-in CNNs in the Conv1DTower class.
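
As a hypothetical sketch, a built-in CNN forwards its convolutional keyword arguments to Conv1DTower; the conv_kwargs keyword and the key names inside it (conv_channels, conv_kernels) are assumptions for illustration and should be checked against the Conv1DTower signature.

from eugene import models

# The keys accepted inside conv_kwargs are defined by Conv1DTower; the names
# used below are assumed for illustration only.
arch = models.zoo.CNN(
    input_len=100,
    output_dim=1,
    conv_kwargs={"conv_channels": [32, 64], "conv_kernels": [13, 5]},
)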

models.Tower(block, repeats, input_size[, ...])

A tower of blocks.

models.Conv1DTower(input_len, ...[, ...])

Generates a PyTorch module for multiple convolutional layers

LightningModules

models.SequenceModule(arch[, task, ...])

Base LightningModule class for EUGENe that handles models that predict single tensor outputs.

models.ProfileModule(arch[, task, ...])

LightningModule class for training models that predict profile data (both shape and count).

Initialization

models.init_weights(model[, initializer])

Initialize the weights of a model.

models.init_motif_weights(model, layer_name, ...)

Initialize the convolutional kernel of choice using a set of motifs

Zoo

Arguments for the cnn_kwargs, recurrent_kwargs and dense_kwargs of all models can be found in the Conv1DTower, RecurrentBlock and DenseBlock classes, respectively. See the blocks section and the towers section for more information. The Satori architecture currently uses the MultiHeadAttention layer found in eugene.models.base._layers; see that layer for more information on the mha_kwargs argument.
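
A few examples of instantiating published architectures with their documented parameters (the values loosely follow the original papers but are illustrative here):

from eugene import models

# input_len and output_dim are the documented parameters; values are illustrative.
danq = models.zoo.DanQ(input_len=1000, output_dim=919)
basset = models.zoo.Basset(input_len=600, output_dim=164)
deepstarr = models.zoo.DeepSTARR(input_len=249, output_dim=2)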

models.zoo.FCN(input_len, output_dim[, ...])

Basic fully connected network

models.zoo.dsFCN(input_len, output_dim[, ...])

Basic FCN model with reverse complement

models.zoo.CNN(input_len, output_dim, ...[, ...])

Basic convolutional network

models.zoo.dsCNN(input_len, output_dim, ...)

Basic CNN model with reverse complement

models.zoo.RNN(input_len, output_dim, ...[, ...])

Basic recurrent network

models.zoo.dsRNN(input_len, output_dim, ...)

Basic RNN model with reverse complement

models.zoo.Hybrid(input_len, output_dim, ...)

Basic hybrid network

models.zoo.dsHybrid(input_len, output_dim, ...)

Basic hybrid network with reverse complement

models.zoo.TutorialCNN(input_len, output_dim)

Tutorial CNN model

models.zoo.DeepBind(input_len, output_dim[, ...])

DeepBind architecture from Alipanahi et al 2015 implemented in PyTorch

models.zoo.ResidualBind(input_len, output_dim)

ResidualBind architecture from Koo et al 2021 implemented in PyTorch

models.zoo.Kopp21CNN(input_len, output_dim)

Custom convolutional model used in Kopp et al. 2021 paper.

models.zoo.DeepSEA(input_len, output_dim[, ...])

DeepSEA model implementation from Zhou et al 2015 in PyTorch

models.zoo.Basset(input_len, output_dim[, ...])

Basset model implementation from Kelley et al 2016 in PyTorch

models.zoo.FactorizedBasset([input_len, ...])

Factorized Basset model implementation from Wnuk et al 2017 in PyTorch

models.zoo.DanQ(input_len, output_dim[, ...])

DanQ model from Quang and Xie 2016 in PyTorch

models.zoo.Satori(input_len, output_dim[, ...])

Satori model from Ullah and Ben-Hur 2021 in PyTorch

models.zoo.Jores21CNN(input_len, output_dim)

Custom convolutional model used in Jores et al. 2021 paper.

models.zoo.DeepSTARR(input_len, output_dim)

DeepSTARR model from de Almeida et al 2022

models.zoo.BPNet(input_len, output_dim[, ...])

BPNet architecture (Avsec et al 2021), adapted from an external PyTorch implementation

models.zoo.DeepMEL(input_len, output_dim[, ...])

DeepMEL model implementation from Minnoye et al 2020 in PyTorch

models.zoo.scBasset(num_cells[, ...])

scBasset model implementation from Yuan et al 2022 in PyTorch

Utilities

models.list_available_layers(model)

List all layers in a model

models.get_layer(model, layer_name)

Get a layer from a model by name

models.load_config(config_path, **kwargs)

Instantiate a module or architecture from a config file

train

from eugene import train

Training procedures for data and models.
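
A minimal sketch, assuming model is a SequenceModule and sdata is a preprocessed SeqData object; only the documented positional arguments are shown.

from eugene import train

# Fit a SequenceModule directly from a SeqData object; optional keyword
# arguments (epochs, loggers, etc.) are omitted here.
train.fit_sequence_module(model, sdata)

# Alternatively, fit any LightningModule from an explicit PyTorch DataLoader.
# train.fit(model, train_dataloader)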

train.fit(model, train_dataloader[, ...])

Fit a model using PyTorch Lightning.

train.fit_sequence_module(model, sdata[, ...])

Fit a SequenceModule using PyTorch Lightning.

train.hyperopt()

Run hyperparameter optimization for a model.

evaluate

from eugene import evaluate

Evaluation functions for trained models, including both prediction helpers and metrics.
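
A minimal sketch; predictions() uses only its documented positional arguments, while the sdata keyword passed to predictions_sequence_module() is an assumption based on the rest of the workflow.

from eugene import evaluate

# Predictions from an explicit PyTorch DataLoader.
preds = evaluate.predictions(model, dataloader)

# Convenience wrapper for a SequenceModule/SeqData pair (the sdata keyword name
# is an assumption; results are expected to be written back into sdata).
evaluate.predictions_sequence_module(model, sdata=sdata)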

Predictions

evaluate.predictions(model, dataloader[, ...])

Predictions from a model and dataloader.

evaluate.predictions_sequence_module(model)

Predictions for a SequenceModule model and SeqData

evaluate.train_val_predictions(model, ...[, ...])

Predictions from a model and train/val dataloaders.

evaluate.train_val_predictions_sequence_module(model)

Train and validation predictions for a SequenceModule model and SeqData

interpret

from eugene import interpret

Interpretation suite of EUGENe, currently divided into filter visualization, feature attribution, and in silico experimentation.
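
A minimal sketch of the three kinds of interpretation, assuming model and sdata come from the training steps above; the layer name keyword and the number of rounds are illustrative assumptions.

from eugene import interpret

# Filter interpretation: PFMs for a convolutional layer (the layer_name keyword
# and its value are assumptions for illustration).
interpret.generate_pfms_sdata(model, sdata, layer_name="conv1d_tower.layers.0")

# Attribution analysis: per-nucleotide attribution scores for each sequence.
interpret.attribute_sdata(model, sdata)

# In silico evolution of the sequences for a few rounds.
interpret.evolve_seqs_sdata(model, sdata, rounds=3)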

Filter interpretation

interpret.generate_pfms_sdata(model, sdata, ...)

Generate position frequency matrices (PFMs) for a given layer in a PyTorch model.

interpret.filters_to_meme_sdata(sdata, ...)

Convert position frequency matrices (PFMs) to a MEME motif file.

Attribution analysis

interpret.attribute_sdata(model, sdata[, ...])

Compute attributions for model and SeqData combination.

Global importance analysis (GIA)

interpret.positional_gia_sdata(model, sdata, ...)

Implant a feature into all sequences in an xarray dataset and return the model predictions.

interpret.motif_distance_dependence_gia(...)

Calculate the dependence of the model predictions on the distance between two motifs.

Generative

interpret.evolve_seqs_sdata(model, sdata, rounds)

In silico evolve a set of sequences that are stored in a SeqData object.

plot

from eugene import plot

Plotting suite in EUGENe for multiple aspects of the workflow.
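
A minimal sketch covering the main plot families; the variable names ("activity", "batch", "label") and the log path are illustrative assumptions.

from eugene import plot

# Categorical summaries of SeqData variables.
plot.histplot(sdata, vars="activity")
plot.boxplot(sdata, vars="activity", groupby="batch")

# Training summaries from a PyTorch Lightning log directory (path is illustrative).
plot.loss_curve("lightning_logs/version_0")

# Performance: AUROC of one or more predictions against one or more targets.
plot.auroc(sdata, target_vars="label", prediction_vars="label_predictions")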

Categorical plotting

plot.countplot(sdata, vars[, groupby, ...])

Plots a countplot of one or more vars in a SeqData using Seaborn.

plot.histplot(sdata, vars[, orient, ...])

Plots a histogram of one or more vars in a SeqData using Seaborn.

plot.boxplot(sdata, vars[, groupby, orient, ...])

Plots a boxplot of one or more vars in a SeqData using Seaborn.

plot.violinplot(sdata[, vars, groupby, ...])

Plots a violinplot of one or more vars in a SeqData using Seaborn.

plot.scatterplot(sdata, x, y[, seq_idx, ...])

Plots a scatterplot of two variables in a SeqData using Seaborn.

Training summaries

plot.metric_curve(log_path, metric[, hue, ...])

Plots the curve of a given metric from a PyTorch Lightning (PL) training run.

plot.loss_curve(log_path[, title, xlab, ...])

Plots the loss curves from a PyTorch Lightning (PL) training run.

plot.training_summary(log_path[, metric, ...])

Plots the training summary from a given training run

Performance

plot.performance_scatter(sdata, target_vars, ...)

Plot a scatter plot of the performance of the model on a subset of the sequences.

plot.confusion_mtx(sdata, target_var, ...[, ...])

Plot a confusion matrix for given targets and predictions within SeqData

plot.auroc(sdata, target_vars, prediction_vars)

Plot the area under the receiver operating characteristic curve for one or more predictions against one or more targets.

plot.auprc(sdata, target_vars, prediction_vars)

Plot the area under the precision-recall curve for one or more predictions against one or more targets.

plot.performance_summary(sdata, target_var)

Plot a performance summary across model predictions for a given metric.

Sequences

plot.seq_track(sdata, seq_id, attrs_var[, ...])

Plot a track of the importance scores for a sequence using the logomaker package

plot.multiseq_track(sdata, seq_ids, attrs_vars)

Plot the saliency tracks for multiple sequences across multiple importance scores in one plot.

plot.filter_viz(sdata, filter_num, pfms_var)

Plot the PFM for a single filter stored in a SeqData object as a PWM logo

plot.multifilter_viz(sdata, filter_nums, ...)

Plot the PFMs for multiple filters stored in a SeqData object as PWM logos.

Global importance analysis (GIA)

plot.positional_gia_plot(sdata, vars[, ...])

Plot a lineplot for each position of the sequence after implanting a feature.

plot.distance_cooperativity_gia_plot(sdata)

Plot the median predicted cooperativity as a function of motif pair distance.

utils

File I/O

utils.make_dirs(output_dir[, overwrite])

Make a directory if it doesn't exist.