multiDGD.nn
- class multiDGD.nn.Decoder(in_features: int, parameter_dictionary: dict)
Class for the DGD decoder with output modules modelling data according to a modality-specific distribution.
Arguments
- in_features: int
Number of input features (i.e. representation size)
- parameter_dictionary: dict
Dictionary containing the parameters for the decoder, including:
  - n_hidden: number of hidden layers in the main body
  - n_units: number of units in the hidden layers
  - n_hidden_modality: number of hidden layers in the output modules
  - n_features_per_modality: number of features in each output module
  - modalities: list of modalities to be modelled
  - decoder_width: width of the output layers (as a multiplier for n_units)
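As an illustration of the expected structure, a parameter_dictionary might look like the sketch below. The key names come from the list above. The values, their types (for example, lists for n_features_per_modality and modalities, and modality names such as "rna" and "atac"), and the helper check_parameter_dictionary are hypothetical.

```python
# Hypothetical parameter_dictionary for multiDGD.nn.Decoder.
# Key names follow the documentation; values and types are assumptions.
parameter_dictionary = {
    "n_hidden": 2,                            # hidden layers in the shared main body
    "n_units": 100,                           # units per hidden layer
    "n_hidden_modality": 1,                   # hidden layers in each output module
    "n_features_per_modality": [2000, 5000],  # e.g. genes, peaks
    "modalities": ["rna", "atac"],            # one output module per modality
    "decoder_width": 1,                       # multiplier for n_units in output layers
}

REQUIRED_KEYS = {"n_hidden", "n_units", "n_hidden_modality",
                 "n_features_per_modality", "modalities", "decoder_width"}


def check_parameter_dictionary(params: dict) -> None:
    """Hypothetical sanity check, not part of multiDGD."""
    missing = REQUIRED_KEYS - set(params)
    if missing:
        raise KeyError(f"missing keys: {sorted(missing)}")
    # One feature count is expected per modality.
    if len(params["n_features_per_modality"]) != len(params["modalities"]):
        raise ValueError("need one entry in n_features_per_modality per modality")


check_parameter_dictionary(parameter_dictionary)
```

The dictionary would then be passed together with the representation size, as in `multiDGD.nn.Decoder(in_features, parameter_dictionary)`; whether the library accepts exactly these value types is not confirmed here.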
Attributes
- main: torch.nn.modules.container.ModuleList
The part of the decoder shared between modalities (if there are several). It takes the representation as input and passes its output to out_modules.
- out_modules: torch.nn.modules.container.ModuleList
Modality-specific output modules. Each models its data with a specific distribution, which is used to calculate the loss, and may include additional linear layers.
- n_out_groups: int
Number of output groups (i.e. modalities)
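The split between main and out_modules can be pictured with the minimal PyTorch sketch below. SketchDecoder is hypothetical and is not the library's class: the ReLU activations are an assumption, and the heads are plain linear layers without the distribution layers that the real output modules carry.

```python
import torch
from torch import nn


class SketchDecoder(nn.Module):
    """Hypothetical, simplified stand-in for multiDGD.nn.Decoder.

    A shared body ("main") maps the representation to n_units features;
    one head per modality ("out_modules") maps that to the modality's features.
    """

    def __init__(self, in_features: int, n_hidden: int, n_units: int,
                 n_features_per_modality: list[int]):
        super().__init__()
        layers = []
        n_in = in_features
        for _ in range(n_hidden):
            # Assumed activation: ReLU after each shared hidden layer.
            layers += [nn.Linear(n_in, n_units), nn.ReLU()]
            n_in = n_units
        self.main = nn.ModuleList(layers)
        # One output head per modality, all fed by the shared body.
        self.out_modules = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_out in n_features_per_modality
        )
        self.n_out_groups = len(n_features_per_modality)

    def forward(self, z: torch.Tensor) -> list[torch.Tensor]:
        for layer in self.main:
            z = layer(z)
        # One output tensor per modality (output group).
        return [head(z) for head in self.out_modules]
```

For example, `SketchDecoder(20, 2, 100, [50, 30])(torch.randn(8, 20))` returns two tensors of shapes (8, 50) and (8, 30), one per output group.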
Methods
- forward(z)
forward pass through the decoder
- log_prob(nn_output, target, scale=None, mod_id=None, gene_id=None, reduction='sum', mask=None)
Calculate the log probability of the output (nn_output) given the target. If mod_id is None, the log-prob is calculated for all modalities and summed; otherwise it is calculated only for the given modality. It can also be restricted to specific features (gene_id) and cells (mask). scale is the scaling factor of each cell, usually its count depth. reduction can be 'sum', 'sample' or 'none': 'sum' sums the log-prob over all cells and features, 'sample' sums over features only (one value per cell), and 'none' applies no summation.
- loss(nn_output, target, scale=None, mod_id=None, gene_id=None, reduction='sum', mask=None)
Calculate the loss of the model predictions (nn_output) given the targets, defined as the negative log probability (see log_prob).
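The reduction options of log_prob, and the sign flip that turns it into loss, can be sketched on an elementwise log-probability tensor of shape (n_cells, n_features). The function names are hypothetical, and treating mask as a boolean selection over cells and gene_id as feature indices is an assumption.

```python
from typing import Optional

import torch


def reduce_log_prob(log_prob: torch.Tensor, reduction: str = "sum",
                    mask: Optional[torch.Tensor] = None,
                    gene_id: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reduce an elementwise log-prob tensor of shape (n_cells, n_features)."""
    if mask is not None:        # assumed: boolean mask selecting cells
        log_prob = log_prob[mask]
    if gene_id is not None:     # assumed: indices selecting features
        log_prob = log_prob[:, gene_id]
    if reduction == "sum":      # scalar over all cells and features
        return log_prob.sum()
    if reduction == "sample":   # one value per cell, summed over features
        return log_prob.sum(dim=-1)
    if reduction == "none":     # no summation at all
        return log_prob
    raise ValueError(f"unknown reduction: {reduction!r}")


def loss_from_log_prob(log_prob: torch.Tensor, reduction: str = "sum",
                       **kwargs) -> torch.Tensor:
    # The loss is the negative log probability.
    return -reduce_log_prob(log_prob, reduction, **kwargs)


lp = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]])
print(reduce_log_prob(lp, "sum"))     # tensor(-10.)
print(reduce_log_prob(lp, "sample"))  # tensor([-3., -7.])
print(loss_from_log_prob(lp))         # tensor(10.)
```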
- class multiDGD.nn.NB_Layer(out_features, r_init=2, scaling_type='sum')
This is the Negative Binomial version of the OutputModule distribution layer.
Arguments
- out_features: int
number of features that come out of this layer
- r_init: int
initial value for the log-dispersion parameter
- scaling_type: str
type of scaling to be applied to the output
Attributes
- fc: torch.nn.modules.container.ModuleList
- log_r: torch.nn.parameter.Parameter
log-dispersion parameter per feature
- dispersion: torch.nn.parameter.Parameter
dispersion parameter per feature
Methods
- property dispersion
returns the dispersion parameter
- forward(x)
forward pass through the NB layer
- log_prob(model_output, target, scaling_factor, gene_id=None, mask=None)
returns the log-prob of the NB layer
- loss(model_output, target, scaling_factor, gene_id=None, mask=None)
returns the loss of the NB layer
- norm_abs_error(model_output, target, scaling_factor, gene_id=None, mask=None)
returns the normalized absolute error of the NB layer
- static rescale(scaling_factor, model_output)
rescales the model output (the normalized mean count) by the scaling factor
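The NB methods can be illustrated with a plain-Python sketch of the Negative Binomial log-probability in its mean/dispersion form, log p(k) = lnΓ(k + r) − lnΓ(r) − lnΓ(k + 1) + r·ln(r / (r + μ)) + k·ln(μ / (r + μ)). It assumes that rescale multiplies the normalized mean by the scaling factor (count depth), that the dispersion is exp(log_r), and that the loss is the negative log-prob. The functions mirror the documented methods by name but are hypothetical scalar stand-ins, not the library's tensor implementation.

```python
import math


def rescale(scaling_factor: float, model_output: float) -> float:
    # Assumed: expected count = count depth * normalized mean.
    return scaling_factor * model_output


def nb_log_prob(k: int, mean: float, r: float) -> float:
    """Negative Binomial log-pmf with mean `mean` (> 0) and dispersion `r` (> 0)."""
    return (math.lgamma(k + r) - math.lgamma(r) - math.lgamma(k + 1)
            + r * math.log(r / (r + mean))
            + k * math.log(mean / (r + mean)))


log_r = math.log(2.0)            # log-dispersion parameter for one feature
r = math.exp(log_r)              # dispersion
mean = rescale(500.0, 0.01)      # count depth 500, normalized mean 0.01
loss = -nb_log_prob(3, mean, r)  # loss for an observed count of 3
```

In this parametrization the variance is μ + μ²/r, so smaller dispersion values allow more overdispersion relative to a Poisson model.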
- class multiDGD.nn.OutputModule(in_features: int, out_features: int, n_hidden: int, hidden_features: int, modality: str, layer_width: int)
This is the base output module class that sits between the decoder and the output data.
Arguments
- in_features: int
number of features going into this layer
- out_features: int
number of features that come out of this layer
- n_hidden: int
number of hidden layers
- hidden_features: int
number of features in hidden layers
- modality: str
modality of the data
- layer_width: int
width of the output layers (as a multiplier for n_units)
Attributes
- fc: torch.nn.modules.container.ModuleList
- n_in: int
number of hidden units going into this layer
- n_out: int
number of features that come out of this layer
- distribution: torch.nn.modules.module.Module
the specific class depends on the modality argument
Methods
- forward(x)
forward pass through the output module
- forward_shap(x)
placeholder for SHAP compatibility
- log_prob(model_output, target, scaling_factor, gene_id=None, mask=None)
returns log-prob of the output module
- loss(model_output, target, scaling_factor, gene_id=None, mask=None)
returns loss of the output module
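The data flow through an output module can be sketched as hidden layers followed by a linear layer to out_features. SketchOutputModule is hypothetical: the ReLU activations and the softmax output (non-negative normalized means that sum to 1 per cell) are assumptions, and layer_width, the modality-specific distribution and forward_shap are omitted.

```python
import torch
from torch import nn


class SketchOutputModule(nn.Module):
    """Hypothetical simplified output module (not multiDGD's implementation)."""

    def __init__(self, in_features: int, out_features: int,
                 n_hidden: int, hidden_features: int):
        super().__init__()
        layers = []
        n_in = in_features
        for _ in range(n_hidden):
            # Assumed activation: ReLU after each hidden layer.
            layers += [nn.Linear(n_in, hidden_features), nn.ReLU()]
            n_in = hidden_features
        # Final linear layer to the modality's features.
        layers.append(nn.Linear(n_in, out_features))
        self.fc = nn.ModuleList(layers)
        self.n_in = in_features
        self.n_out = out_features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.fc:
            x = layer(x)
        # Assumed output activation: softmax over features, so each cell's
        # predicted normalized means are non-negative and sum to 1.
        return torch.softmax(x, dim=-1)
```

Expected counts for a cell would then follow by multiplying these normalized means by its scaling factor, as NB_Layer.rescale describes.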