#!/usr/bin/env bash
# Create a conda environment containing scvi-tools 1.1.2 on a pinned
# JAX 0.4.16 / jaxlib 0.4.16 stack, plus flax, all from conda-forge.
#
# Usage: [ENV_NAME=name] bash <this-script>
#   ENV_NAME  name of the environment to create (default: tmp)
#
# Requires `conda` and `mamba` to be on PATH.
set -euo pipefail

readonly ENV_NAME="${ENV_NAME:-tmp}"

# Fail early with a clear message if a required tool is missing.
for tool in conda mamba; do
  if ! command -v "$tool" >/dev/null; then
    printf 'error: %s not found on PATH\n' "$tool" >&2
    exit 1
  fi
done

# Create the environment with just the interpreter. Python is taken from
# conda-forge so that it is not mixed with the conda-forge packages below.
conda create --yes --name "$ENV_NAME" -c conda-forge python=3.10

# Install into the new environment explicitly (--name) instead of relying on
# `conda activate`, which does not work in a non-interactive script without
# the conda shell hook. Everything is resolved in ONE solve, so the solver
# picks a flax build that is compatible with the pinned jax/jaxlib. A
# separate, unpinned `install flax` step could instead upgrade them.
mamba install --yes --name "$ENV_NAME" -c conda-forge \
  scvi-tools=1.1.2 jaxlib=0.4.16 jax=0.4.16 conda-forge::flax