multiDGD.nn

class multiDGD.nn.Decoder(in_features: int, parameter_dictionary: dict)

Class for the DGD decoder with output modules modelling data according to a modality-specific distribution.

Arguments

in_features: int

Number of input features (i.e. representation size)

parameter_dictionary: dict

Dictionary containing the parameters for the decoder, including:

- n_hidden: number of hidden layers in the main body
- n_units: number of units in the hidden layers
- n_hidden_modality: number of hidden layers in the output modules
- n_features_per_modality: number of features in each output module
- modalities: list of modalities to be modelled
- decoder_width: width of the output layers (as a multiplier for n_units)
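As an illustration, a parameter_dictionary with the keys listed above might look like the following sketch. All values, the modality names, and the representation size are hypothetical, and the container types used for n_features_per_modality and modalities are assumptions. The constructor call is left commented out because it requires multiDGD to be installed.

```python
# Hypothetical values; only the key names come from the documentation above.
parameter_dictionary = {
    "n_hidden": 2,
    "n_units": 100,
    "n_hidden_modality": 1,
    "n_features_per_modality": [2000, 5000],  # assumed: one entry per modality
    "modalities": ["rna", "atac"],  # assumed modality names
    "decoder_width": 2,
}

# decoder_width is described as a multiplier for n_units
output_layer_width = (
    parameter_dictionary["decoder_width"] * parameter_dictionary["n_units"]
)
print(output_layer_width)

# With multiDGD installed, the decoder would be built along these lines
# (in_features=20 is an arbitrary representation size):
# from multiDGD.nn import Decoder
# decoder = Decoder(in_features=20, parameter_dictionary=parameter_dictionary)
```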

Attributes

main: torch.nn.modules.container.ModuleList

The part of the decoder shared between modalities (if there is more than one). It takes the representation as input and passes its output to out_modules.

out_modules: torch.nn.modules.container.ModuleList

Modality-specific output modules. Each applies a specific distribution when calculating the loss and can include additional linear layers.

n_out_groups: int

Number of output groups (i.e. modalities)

Methods

forward(z)

forward pass through the decoder

log_prob(nn_output, target, scale=None, mod_id=None, gene_id=None, reduction='sum', mask=None)

Calculate the log probability of the target under the distribution parameterized by the decoder output (nn_output). If mod_id is None, the log_prob is calculated for all modalities and summed; otherwise it is calculated only for the given modality. It can also be restricted to specific sets of features (gene_id) and cells (mask). scale is the scaling factor of each cell, usually the count depth. reduction can be 'sum', 'sample' or 'none': 'sum' sums the log_prob over all cells and features, 'sample' sums over all features but not over cells, and 'none' performs no summation over cells or features.
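The three reduction modes can be illustrated with a small plain-Python sketch. It is not the library implementation: reduce_log_prob is a hypothetical helper, the log-probabilities are made-up values, and nested lists stand in for tensors of shape (cells, features).

```python
def reduce_log_prob(log_probs, reduction="sum"):
    """Reduce a cells x features table of log-probabilities.

    log_probs is a list of rows, one row per cell and one entry per feature.
    """
    if reduction == "sum":
        # single scalar: sum over all cells and features
        return sum(sum(row) for row in log_probs)
    if reduction == "sample":
        # one value per cell: sum over features only
        return [sum(row) for row in log_probs]
    if reduction == "none":
        # no summation: one value per cell and feature
        return [list(row) for row in log_probs]
    raise ValueError(f"unknown reduction: {reduction!r}")


# Two cells, three features (made-up values)
log_probs = [[-1.0, -2.0, -3.0], [-0.5, -0.5, -1.0]]
total = reduce_log_prob(log_probs, "sum")        # -8.0
per_cell = reduce_log_prob(log_probs, "sample")  # [-6.0, -2.0]
unreduced = reduce_log_prob(log_probs, "none")   # same shape as the input
```

Negating the reduced value gives the corresponding negative log probability.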

loss(nn_output, target, scale=None, mod_id=None, gene_id=None, reduction='sum', mask=None)

Calculate the loss of the model predictions (nn_output) given the targets, as the negative log probability (see log_prob).

class multiDGD.nn.NB_Layer(out_features, r_init=2, scaling_type='sum')

This is the Negative Binomial version of the OutputModule distribution layer.

Arguments

out_features: int

number of features that come out of this layer

r_init: int

initial value for the log-dispersion parameter

scaling_type: str

type of scaling to be applied to the output

Attributes

fc: torch.nn.modules.container.ModuleList

log_r: torch.nn.parameter.Parameter

log-dispersion parameter per feature

dispersion: torch.nn.parameter.Parameter

dispersion parameter per feature

Methods

property dispersion

returns the dispersion parameter

forward(x)

forward pass through the NB layer

log_prob(model_output, target, scaling_factor, gene_id=None, mask=None)

returns the log-prob of the NB layer
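For orientation, the sketch below evaluates a negative binomial log-probability for a single count using the common mean/dispersion parameterization. This is an assumption made for illustration: the exact parameterization, numerical safeguards, and tensor handling used by NB_Layer are not stated here, and nb_log_prob is a hypothetical helper rather than part of multiDGD.

```python
import math


def nb_log_prob(k, mean, r):
    """Log-probability of count k under a negative binomial with the given
    mean and dispersion r (variance = mean + mean**2 / r)."""
    return (
        math.lgamma(k + r) - math.lgamma(r) - math.lgamma(k + 1)
        + r * math.log(r / (r + mean))
        + k * math.log(mean / (r + mean))
    )


# Sanity check: probabilities over counts sum to ~1 and reproduce the mean
mean, r = 3.0, 2.0
probs = [math.exp(nb_log_prob(k, mean, r)) for k in range(500)]
total_mass = sum(probs)
expected_count = sum(k * p for k, p in enumerate(probs))
```

In this parameterization a larger r gives a variance closer to the mean, so the distribution approaches a Poisson as r grows.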

loss(model_output, target, scaling_factor, gene_id=None, mask=None)

returns the loss of the NB layer

norm_abs_error(model_output, target, scaling_factor, gene_id=None, mask=None)

returns the normalized absolute error of the NB layer

static rescale(scaling_factor, model_output)

rescales the model output (mean normalized count)
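A minimal sketch of this rescaling, assuming that the model output holds normalized means per cell and that rescaling multiplies each cell's row by that cell's scaling factor (for example its count depth). rescale_sketch is a hypothetical plain-Python stand-in with made-up inputs; how the library's method depends on scaling_type is not covered here.

```python
def rescale_sketch(scaling_factor, model_output):
    """Scale each cell's normalized output by that cell's scaling factor.

    scaling_factor: one value per cell (e.g. count depth)
    model_output: one row per cell, one normalized mean per feature
    """
    return [
        [s * value for value in row]
        for s, row in zip(scaling_factor, model_output)
    ]


# Two cells with count depths 1000 and 400 (made-up values)
model_output = [[0.5, 0.25, 0.25], [0.75, 0.125, 0.125]]
predicted_means = rescale_sketch([1000.0, 400.0], model_output)
# predicted_means == [[500.0, 250.0, 250.0], [300.0, 50.0, 50.0]]
```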

class multiDGD.nn.OutputModule(in_features: int, out_features: int, n_hidden: int, hidden_features: int, modality: str, layer_width: int)

This is the base output module class that sits between the decoder and the output data.

Arguments

in_features: int

number of features going into this layer

out_features: int

number of features that come out of this layer

n_hidden: int

number of hidden layers

hidden_features: int

number of features in hidden layers

modality: str

modality of the data

layer_width: int

width of the output layers (as a multiplier for n_units)

Attributes

fc: torch.nn.modules.container.ModuleList

n_in: int

number of hidden units going into this layer

n_out: int

number of features that come out of this layer

distribution: torch.nn.modules.module.Module

specific class depends on modality argument

Methods

forward(x)

forward pass through the output module

forward_shap(x)

placeholder for SHAP compatibility

log_prob(model_output, target, scaling_factor, gene_id=None, mask=None)

returns log-prob of the output module

loss(model_output, target, scaling_factor, gene_id=None, mask=None)

returns loss of the output module