Plotting
After training, you can plot the training history to check whether the model has converged or is overfitting:
model.plot_history()
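As a rough aid for reading such a history plot, the sketch below classifies a pair of loss curves. It is not part of the model's API: `diagnose_history`, its `window` and `tol` parameters, and the loss values are all hypothetical, and it assumes you have per-epoch training and validation losses available as plain lists.

```python
def diagnose_history(train_loss, val_loss, window=5, tol=1e-3):
    """Classify loss curves as 'overfitting', 'converged', or 'still training'.

    Hypothetical helper, not part of the model API. It compares the mean of
    the last `window` epochs with the mean of the `window` epochs before them.
    """
    if len(train_loss) != len(val_loss):
        raise ValueError("train_loss and val_loss must have the same length")
    if len(train_loss) < 2 * window:
        raise ValueError("need at least 2 * window epochs")

    def recent_change(losses):
        previous = sum(losses[-2 * window:-window]) / window
        latest = sum(losses[-window:]) / window
        return latest - previous  # negative means the loss is still falling

    train_delta = recent_change(train_loss)
    val_delta = recent_change(val_loss)

    # Validation loss rising while training loss keeps falling: overfitting.
    if val_delta > tol and train_delta < 0:
        return "overfitting"
    # Both curves flat within the tolerance: converged.
    if abs(train_delta) <= tol and abs(val_delta) <= tol:
        return "converged"
    return "still training"


# Made-up loss values for illustration only.
train = [1.0, 0.8, 0.6, 0.5, 0.4, 0.35, 0.3, 0.25, 0.2, 0.15]
val = [1.1, 0.9, 0.75, 0.7, 0.68, 0.7, 0.75, 0.8, 0.85, 0.9]
print(diagnose_history(train, val, window=3))
```

The thresholds are arbitrary; in practice the plot itself is usually the better guide, and a check like this is mainly useful when comparing many runs.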
Visualizing the representations (embeddings):
PCA of plain representations and the GMM (means and samples):
model.plot_latent_space()
UMAP of representations using scanpy:
import scanpy as sc
data.obsm['latent'] = model.get_representation()
data.obs['cluster'] = model.clustering().astype(str)
sc.pp.neighbors(data, use_rep='latent')
sc.tl.umap(data, min_dist=1.0)
sc.pl.umap(data, color='observable')
sc.pl.umap(data, color='cluster')
Covariate representations (2D):
import seaborn as sns
cov_rep = model.get_covariate_representation()
sns.scatterplot(x=cov_rep[:, 0], y=cov_rep[:, 1], hue=data.obs[data.obs["train_val_test"]=="train"]["Site"].values)
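The scatterplot only works if `hue` has exactly one label per row of the covariate representation. The masking on `train_val_test == "train"` suggests that the representation holds one row per training sample in the same order as `data.obs`, but that is an assumption here. A minimal sketch of checking the alignment before plotting, using a hypothetical helper `training_labels` and stand-in data in place of `data.obs` and the real representation:

```python
import numpy as np
import pandas as pd


def training_labels(obs, column, n_rows, split_column="train_val_test"):
    """Return `column` for the training rows, checking it matches `n_rows`.

    Hypothetical helper, not part of the model API.
    """
    labels = obs.loc[obs[split_column] == "train", column].to_numpy()
    if len(labels) != n_rows:
        raise ValueError(
            f"{len(labels)} training labels but {n_rows} representation rows"
        )
    return labels


# Stand-in data: swap in data.obs and model.get_covariate_representation().
obs = pd.DataFrame({
    "train_val_test": ["train", "test", "train", "validation", "train"],
    "Site": ["site1", "site2", "site2", "site1", "site3"],
})
cov_rep = np.random.default_rng(0).normal(size=(3, 2))  # 3 training rows, 2D

hue = training_labels(obs, "Site", n_rows=cov_rep.shape[0])
print(hue)
```

Failing early with a clear message is easier to debug than a length mismatch raised from inside the plotting call.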
See this notebook for the example plots: