Model customization
multiDGD hyperparameters
Hyperparameters for multiDGD are passed as a dictionary. The package defines a default set of hyperparameters. Users can overwrite any of them by passing their own dictionary when initializing the model, as shown below.
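```python
import multiDGD

# Entries given here overwrite the corresponding default hyperparameters.
custom_parameters = {
    # custom hyperparameters
}

# Initialize the model with the default parameters, updated by custom_parameters.
# `data` is the training dataset, assumed to be loaded beforehand.
model = multiDGD.DGD(
    data=data,
    parameter_dictionary=custom_parameters,
)
```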
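The sketch below shows how a partial dictionary can be combined with a set of defaults. It assumes a shallow, key-by-key merge. `DEFAULTS`, `resolve_parameters`, and the key names and values are hypothetical placeholders. They are not multiDGD's actual defaults or internals.

```python
# Hypothetical placeholder defaults; these are NOT multiDGD's real keys or values.
DEFAULTS = {
    "latent_dim": 10,
    "batch_size": 128,
}


def resolve_parameters(custom=None):
    """Return the defaults with any user-supplied entries overriding them.

    Assumes a shallow merge: a user value replaces the default value wholesale.
    """
    merged = dict(DEFAULTS)  # copy so the defaults themselves are never mutated
    merged.update(custom or {})
    return merged


params = resolve_parameters({"batch_size": 256})
print(params["batch_size"], params["latent_dim"])  # 256 10
```

Under this assumption, a list-valued entry such as the learning rates would be replaced as a whole. You would then need to supply all of its values, not just the one you want to change.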
The following table lists the hyperparameters that can be customized, as well as their default values:
| Argument | Type | Default | Description |
|---|---|---|---|
| | int | | Dimensionality of the latent space. |
| | int | | Number of components in the Gaussian mixture model (GMM). |
| | int | | Number of hidden layers in the shared decoder (\(\theta_{h}\)). |
| | int | | Number of hidden layers in each modality-specific decoder (\(\theta_{h_m}\)). |
| | int | | Number of units in each hidden layer, except the last, whose width is \(\max\{100, \sqrt{\lvert\text{features}\rvert}\}\). |
| | str | | Initialization of the weights. Options are |
| | float | | Scale parameter of the Softball prior (see manuscript). It sets the size of the sphere over which the (mollified uniform) prior on the component means is defined. |
| | float | | Hardness parameter of the Softball prior (see manuscript). It determines how sharply the prior drops from high to low probability at the boundary of the sphere. |
| | float | | Standard deviation of the Gaussian prior over the negative log covariance. It has little effect in practice and can be left at 1. The mean of this prior is determined by the number of components and the softball scale. |
| | float | | Same as |
| | float | | Same as |
| | float | | Same as |
| | float | | Concentration parameter of the Dirichlet prior over the mixture weights. Higher values enforce equal mixture weights more strongly. |
| | int | | Batch size for training. |
| | list | | Learning rates for the three parameter sets: decoder, representations, and GMM. |
| | list | | Beta coefficients \((\beta_1, \beta_2)\) for the Adam optimizer. |
| | float | | Weight decay for the Adam optimizer. |
| | int | | Factor by which the number of units in every hidden layer is multiplied (to bypass the last-layer width rule). |
| | list | | List of strings for logging to wandb. |
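The hidden-layer width rule and the width multiplier can be sketched as follows. `hidden_layer_widths` is a hypothetical helper. Two details are assumptions rather than documented behaviour: that \(\sqrt{\lvert\text{features}\rvert}\) is rounded down, and that the multiplier is applied to every hidden layer, including the last one.

```python
import math


def hidden_layer_widths(n_layers, n_units, n_features, width_factor=1):
    """Hypothetical helper returning the widths of one decoder's hidden layers."""
    if n_layers < 1:
        return []
    # Every hidden layer except the last uses the configured number of units.
    widths = [n_units] * (n_layers - 1)
    # Last hidden layer: max{100, sqrt(|features|)}, rounded down (assumption).
    widths.append(max(100, math.isqrt(n_features)))
    # The width multiplier scales every hidden layer (assumption).
    return [w * width_factor for w in widths]


print(hidden_layer_widths(3, 50, 40_000))  # [50, 50, 200]
```

With 10,000 features, \(\sqrt{10000} = 100\), so the last layer stays at 100 units. Only feature counts above 10,000 widen it.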
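The roles of the softball scale and hardness can be illustrated with one common form of a mollified uniform density. This form is a uniform ball whose edge is smoothed by a softplus term. The sketch treats the scale as the sphere's radius. That interpretation, the softplus edge, and the function names are assumptions; the exact prior used by multiDGD is defined in the manuscript.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))


def softball_log_prob(z, scale, hardness):
    """Hypothetical mollified-uniform ('softball') log-density of one point z.

    scale    -> radius of the sphere (assumed interpretation)
    hardness -> steepness of the fall-off at the sphere's boundary
    """
    dim = len(z)
    # log(1 / volume of a dim-dimensional ball of radius `scale`), i.e. the
    # log-density of a hard uniform prior; used here as an approximate normalizer.
    log_norm = math.lgamma(1 + dim / 2) - dim * (math.log(scale) + 0.5 * math.log(math.pi))
    # Relative distance from the origin: < 1 inside the sphere, > 1 outside.
    r = math.hypot(*z) / scale
    # The softplus term is close to 0 well inside the sphere and grows roughly
    # linearly with hardness * (r - 1) outside it.
    return log_norm - softplus(hardness * (r - 1.0))
```

Raising `hardness` makes the density flatter inside the sphere and makes it drop faster outside. This matches the description of the hardness parameter in the table.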
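The effect of the Dirichlet concentration \(\alpha\) follows from the log-density of a symmetric Dirichlet distribution over \(K\) weights:

\[
\log p(w) = \log\Gamma(K\alpha) - K\log\Gamma(\alpha) + (\alpha - 1)\sum_{k=1}^{K}\log w_k.
\]

The sketch assumes a symmetric prior with a single concentration value, as the single float parameter suggests. The function name is hypothetical. For \(\alpha > 1\), equal weights score highest, and the advantage grows with \(\alpha\). At \(\alpha = 1\), the prior is flat.

```python
import math


def symmetric_dirichlet_log_pdf(weights, alpha):
    """Log-density of a symmetric Dirichlet(alpha) at a weight vector summing to 1."""
    k = len(weights)
    log_norm = math.lgamma(k * alpha) - k * math.lgamma(alpha)
    return log_norm + (alpha - 1.0) * sum(math.log(w) for w in weights)


equal = [0.25, 0.25, 0.25, 0.25]
skewed = [0.7, 0.1, 0.1, 0.1]
for alpha in (1.0, 2.0, 5.0):
    # Positive margin: the prior prefers equal mixture weights.
    margin = symmetric_dirichlet_log_pdf(equal, alpha) - symmetric_dirichlet_log_pdf(skewed, alpha)
    print(alpha, margin)
```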