Install notes from BDL2022f env install on 2022-11-28
- For CUDA TOOLKIT 11.3, which can be used on older devices but may not be optimal
- set up basic conda env without any torch or jax packages, via
conda env create -f bdl_2022f.yml
| import numpy as np | |
| import jax | |
| import jax.numpy as jnp | |
| import jax.nn | |
| def calc_trans_mat(x_TD, r_KD, p_KK): | |
| ''' Compute transition matrix for each timestep | |
| Args |
| import jax | |
| import jax.numpy as jnp | |
| if __name__ == '__main__': | |
| print("jax.devices()") | |
| print(jax.devices()) | |
| a = jnp.asarray([[1.0, 2.0, 3.0], [4., 5., 6.]]) | |
| b = jnp.asarray([[1.0, 2.0], [3.0, 4.0], [5., 6.]]) |
Install notes from BDL2022f env install on 2022-11-28
conda env create -f bdl_2022f.yml
| import numpy as np | |
| import scipy.stats | |
| import matplotlib.pyplot as plt | |
| from statsmodels.distributions.empirical_distribution import ECDF | |
| def create_transform_func_to_match_source(target_x_ND, src_x_MD, n_quantiles=1000): | |
| ''' |
| ''' VI for Poisson Normal | |
| Model | |
| ----- | |
| Latent variable z is drawn from a Normal prior: z ~ Normal(40, 10) | |
| Data y is drawn iid from a Poisson likelihood: y_n ~ Poisson(z) | |
| Approx Posterior | |
| ---------------- | |
| Posterior on z is assumed to be Normal with unknown mean and stddev |
| window_size | sample_id | accuracy | |
|---|---|---|---|
| 5.0 | 0 | 0.6236094882645041 | |
| 5.0 | 1 | 0.593111865845944 | |
| 5.0 | 2 | 0.6060493252962028 | |
| 5.0 | 3 | 0.6342719738873018 | |
| 5.0 | 4 | 0.6259239448289695 | |
| 5.0 | 5 | 0.5623114809268821 | |
| 5.0 | 6 | 0.6054087015122116 | |
| 5.0 | 7 | 0.5807796285836049 | |
| 5.0 | 8 | 0.5818560349582886 |
# The same log-spaced grid from 1e0 to 1e5 (np.logspace default num=50),
# stored once at single precision and once at double precision.
vals_float32 = np.logspace(0, 5, dtype=np.float32)
vals_float64 = np.logspace(0, 5, dtype=np.float64)


## Pretty-print output of array so each float takes the same number of chars
def pprint_arr(arr, n_per_line=6):
    ''' Print the values of an array in fixed-width scientific notation

    Args
    ----
    arr : array-like of floats
        Values to print. The input is flattened first, so plain lists and
        multi-dimensional arrays are accepted as well as 1-D arrays.
    n_per_line : int
        Number of values printed on each output line.

    Post Condition
    --------------
    Writes to stdout one line per group of `n_per_line` values. Each value
    is formatted with 2 digits after the decimal point and a 3-digit
    exponent, right-aligned in a 10-character field, with single spaces
    between fields. One empty line is printed after the last group.
    '''
    # Flatten so that stepping by arr.size and slicing refer to the same
    # axis; for a 1-D ndarray this is a no-op view with the same dtype.
    flat_arr = np.ravel(arr)
    for start in range(0, flat_arr.size, n_per_line):
        chunk = flat_arr[start:start + n_per_line]
        print(" ".join(
            "%10s" % np.format_float_scientific(
                x, precision=2, unique=False, exp_digits=3)
            for x in chunk))
    # NOTE(review): the original indentation was lost in extraction; this
    # assumes the blank separator line is printed once, after the loop.
    print()