# Gist by Gilead Turok (giladturok)

from typing import Callable

import jax
import jax.numpy as jnp


class ParameterTransforms:
    """Base specification for parameter transforms."""

    def forward(self, x):
        """Forward transform from unconstrained to constrained space.

        Subclasses must override this method.
        """
        raise NotImplementedError
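    # Illustrative sketch (an assumption, not part of the original gist): a
    # subclass that constrains a parameter to be strictly positive could use
    # the exponential as its forward map:
    #
    #     class PositiveTransform(ParameterTransforms):  # hypothetical name
    #         def forward(self, x):
    #             return jnp.exp(x)
    #
    # With this choice, forward(0.0) == 1.0, and every unconstrained real x
    # maps into the open interval (0, inf).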