Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Last active January 30, 2025 22:32
Show Gist options
  • Save llandsmeer/c56aba4a5a4d0aa249afded929a7d1b0 to your computer and use it in GitHub Desktop.
Save llandsmeer/c56aba4a5a4d0aa249afded929a7d1b0 to your computer and use it in GitHub Desktop.
iojax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import functools
import typing
def main():
    """Simulate one IO cell with IOCellParams(g_CaL=1.2) and plot its somatic voltage.

    The cell is initialised with ``IOCellState.make`` (gating variables at steady
    state). It is then integrated with ``jax.lax.scan`` over an applied-current
    array, one entry per step, injected into the dendrite by ``timestep``. The
    recorded trace holds the state *before* each step, so sample ``i`` is at time
    ``i * delta``.
    """
    params = IOCellParams(g_CaL=1.2)
    state0 = IOCellState.make(params)

    delta = 0.025          # step size handed to timestep (ms)
    n_steps = 10 * 40_000  # 40_000 steps * 0.025 ms = 1 s, so this simulates 10 s

    @jax.jit
    def f(state, iapp):
        # scan body: carry the advanced state forward, emit the pre-step state.
        return timestep(state, params, delta=delta, I_app=iapp), state

    iapp = jnp.zeros(n_steps)  # no injected current
    # or iapp = jnp.linspace(0, -0.1, n_steps)
    _, trace = jax.lax.scan(f, state0, iapp)

    t = jnp.arange(n_steps) * delta  # time axis matching the recorded samples
    plt.plot(t, trace.V_soma)
    plt.xlabel('time (ms)')
    plt.ylabel('V_soma')
    plt.show()
class IOCellParams(typing.NamedTuple):
    """Constant parameters of the IO cell model advanced by ``timestep``.

    ``g_*`` fields are conductances and ``V_*`` fields are the potentials
    subtracted from a compartment voltage to form each current's driving
    force (``V - V_x``). ``timestep`` uses three compartments: soma, axon and
    dendrite.

    The field order is the positional constructor order, so do not reorder.
    Values in trailing parentheses in the field comments appear to be
    alternative/reference values from another parameter set
    (NOTE(review): confirm their source). To vary one parameter, construct
    with keywords, e.g. ``IOCellParams(g_CaL=1.2)``, or use ``_replace``.
    """
    g_int: float = 0.13   # Cell internal conductance -- now a parameter (0.13); compartment coupling, scaled by p1/p2
    p1: float = 0.25      # Cell surface ratio soma/dendrite; coupling uses g_int/p1 (soma side), g_int/(1-p1) (dendrite side)
    p2: float = 0.15      # Cell surface ratio axon(hillock)/soma; coupling uses g_int/p2 (axon side), g_int/(1-p2) (soma side)
    g_CaL: float = 1.1    # Calcium T - (CaV 3.1) (0.7); somatic, gates k^3 * l
    g_h: float = 0.12     # H current (HCN) (0.4996); dendritic, gate q
    g_K_Ca: float = 35.0  # Potassium (KCa v1.1 - BK) (35); dendritic, gate s driven by dend_Ca2Plus
    g_ld: float = 0.01532 # Leak dendrite (0.016)
    g_la: float = 0.016   # Leak axon (0.016)
    g_ls: float = 0.016   # Leak soma (0.016)
    g_Na_s: float = 150.0 # Sodium - (Na v1.6 ); somatic, m_inf^3 * h
    g_Kdr_s: float = 9.0  # Potassium - (K v4.3); somatic, gate n^4
    g_K_s: float = 5.0    # Potassium - (K v3.4); somatic, gate x^4
    g_CaH: float = 4.5    # High-threshold calcium -- Ca V2.1; dendritic, gate r^2
    g_Na_a: float = 240.0 # Sodium; axonal, m_inf^3 * h
    g_K_a: float = 240.0  # Potassium (20); axonal, gate x^4
    S: float = 1.0        # 1/C_m, cm^2/uF; multiplies the net current in every dV/dt
    V_Na: float = 55.0    # Sodium
    V_K: float = -75.0    # Potassium; shared by all K currents (Kdr, K, KCa, axonal K)
    V_Ca: float = 120.0   # Low-threshold calcium channel -- NOTE: also used for the dendritic high-threshold (g_CaH) current
    V_h: float = -43.0    # H current
    V_l: float = 10.0     # Leak; shared by all three compartments
class IOCellState(typing.NamedTuple):
    """Dynamic state of one IO cell: compartment voltages, gating variables and dendritic Ca.

    Fields hold Python floats (the defaults) or JAX arrays, e.g. the outputs of
    ``timestep`` or a stacked trace from ``jax.lax.scan``. The field order
    defines the positional constructor and tuple layout, so do not reorder.
    """

    # soma: voltage plus gates k, l (CaL), h (Na), n (Kdr), x (K)
    V_soma: float | jax.Array = -60.0
    soma_k: float | jax.Array = 0.0
    soma_l: float | jax.Array = 0.0
    soma_h: float | jax.Array = 0.0
    soma_n: float | jax.Array = 0.0
    soma_x: float | jax.Array = 0.1

    # axon: voltage plus Na inactivation h and K activation x
    V_axon: float | jax.Array = -60.0
    axon_Sodium_h: float | jax.Array = 0.0
    axon_Potassium_x: float | jax.Array = 0.0

    # dendrite: voltage, Ca concentration variable, gates r (CaH), s (KCa), q (h current)
    V_dend: float | jax.Array = -60.0
    dend_Ca2Plus: float | jax.Array = 3.715
    dend_Calcium_r: float | jax.Array = 0.0
    dend_Potassium_s: float | jax.Array = 0.0
    dend_Hcurrent_q: float | jax.Array = 0.0

    @classmethod
    def make(cls, params: IOCellParams, V_soma=-49.6, V_axon=-49.6, V_dend=-49.6):
        """Build an initial state whose gating variables sit at steady state.

        A seed state is built with the requested compartment voltages and all
        other fields at their defaults. One ``timestep`` is then run with
        ``steady_state=True``, which replaces every gating variable by its
        steady-state value at the seed voltages.

        NOTE(review): the three voltages and ``dend_Ca2Plus`` are still
        advanced by one Euler step (default ``delta``), using the seed's default
        gating values. The returned voltages therefore differ slightly from the
        requested ones.
        """
        seed = cls(V_soma=V_soma, V_axon=V_axon, V_dend=V_dend)
        equilibrated = timestep(seed, params=params, steady_state=True)
        return equilibrated
def _x_over_expm1(x):
    """Return ``x / (exp(x) - 1)``, finite at the removable singularity ``x == 0``.

    The limit at 0 is 1. Very close to 0 the Taylor expansion ``1 - x/2`` is
    used instead. The inner ``where`` swaps in a harmless value for the unused
    branch, so both the forward value and its gradient stay NaN-free.
    """
    small = jnp.abs(x) < 1e-6
    safe_x = jnp.where(small, 1.0, x)
    return jnp.where(small, 1.0 - x / 2, safe_x / jnp.expm1(safe_x))


@functools.partial(jax.jit, static_argnames=['steady_state'])
def timestep(state: IOCellState, params: IOCellParams, delta=0.025, I_app=0.0, steady_state=False):
    """Advance a three-compartment (soma, axon, dendrite) IO cell by one forward-Euler step.

    Args:
        state: current cell state.
        params: model parameters.
        delta: step size (ms, per the simulation driver's comment).
        I_app: applied current, added to the dendritic dV/dt; positive values
            raise V_dend.
        steady_state: static (compile-time) flag. When True, every gating
            variable is set directly to its steady-state value at the current
            voltages instead of being integrated. Voltages and dend_Ca2Plus are
            Euler-stepped in both modes.

    Returns:
        The updated IOCellState.
    """
    V_s, V_a, V_d = state.V_soma, state.V_axon, state.V_dend

    # ---------------- soma ----------------
    soma_I_leak = params.g_ls * (V_s - params.V_l)
    I_ds = (params.g_int / params.p1) * (V_s - V_d)
    I_as = (params.g_int / (1 - params.p2)) * (V_s - V_a)
    soma_I_interact = I_ds + I_as

    # Low-threshold Ca (g_CaL): activation k (time constant 1), inactivation l.
    soma_Ical = params.g_CaL * state.soma_k * state.soma_k * state.soma_k * state.soma_l * (V_s - params.V_Ca)
    soma_k_inf = 1 / (1 + jnp.exp(-(V_s + 61) / 4.2))
    soma_l_inf = 1 / (1 + jnp.exp((V_s + 85) / 8.5))
    soma_tau_l = (20 * jnp.exp((V_s + 160) / 30) / (1 + jnp.exp((V_s + 84) / 7.3))) + 35
    soma_dk_dt = soma_k_inf - state.soma_k
    soma_dl_dt = (soma_l_inf - state.soma_l) / soma_tau_l

    # Na (g_Na_s): instantaneous activation m_inf, inactivation h.
    soma_m_inf = 1 / (1 + jnp.exp(-(V_s + 30) / 5.5))
    soma_h_inf = 1 / (1 + jnp.exp((V_s + 70) / 5.8))
    soma_Ina = params.g_Na_s * soma_m_inf**3 * state.soma_h * (V_s - params.V_Na)
    soma_tau_h = 3 * jnp.exp(-(V_s + 40) / 33)
    soma_dh_dt = (soma_h_inf - state.soma_h) / soma_tau_h

    # K (g_Kdr_s): gate n.
    soma_Ikdr = params.g_Kdr_s * state.soma_n**4 * (V_s - params.V_K)
    soma_n_inf = 1 / (1 + jnp.exp(-(V_s + 3) / 10))
    soma_tau_n = 5 + (47 * jnp.exp((V_s + 50) / 900))
    soma_dn_dt = (soma_n_inf - state.soma_n) / soma_tau_n

    # K (g_K_s): gate x with alpha/beta kinetics.
    # alpha_x = 0.13*(V+25) / (1 - exp(-(V+25)/10)), rewritten so V = -25 is not 0/0.
    soma_Ik = params.g_K_s * state.soma_x**4 * (V_s - params.V_K)
    soma_alpha_x = 0.13 * 10 * _x_over_expm1(-(V_s + 25) / 10)
    soma_beta_x = 1.69 * jnp.exp(-(V_s + 35) / 80)
    soma_tau_x_inv = soma_alpha_x + soma_beta_x
    soma_x_inf = soma_alpha_x / soma_tau_x_inv
    soma_dx_dt = (soma_x_inf - state.soma_x) * soma_tau_x_inv

    soma_I_Channels = soma_Ik + soma_Ikdr + soma_Ina + soma_Ical
    soma_dv_dt = params.S * (-(soma_I_leak + soma_I_interact + soma_I_Channels))

    # ---------------- axon ----------------
    axon_I_leak = params.g_la * (V_a - params.V_l)
    axon_I_interact = (params.g_int / params.p2) * (V_a - V_s)

    # Na (g_Na_a): instantaneous activation m_inf, inactivation h.
    axon_m_inf = 1 / (1 + jnp.exp(-(V_a + 30) / 5.5))
    axon_h_inf = 1 / (1 + jnp.exp((V_a + 60) / 5.8))
    axon_Ina = params.g_Na_a * axon_m_inf**3 * state.axon_Sodium_h * (V_a - params.V_Na)
    axon_tau_h = 1.5 * jnp.exp(-(V_a + 40) / 33)
    axon_dh_dt = (axon_h_inf - state.axon_Sodium_h) / axon_tau_h

    # K (g_K_a): gate x, same kinetics as the somatic x gate.
    axon_Ik = params.g_K_a * state.axon_Potassium_x**4 * (V_a - params.V_K)
    axon_alpha_x = 0.13 * 10 * _x_over_expm1(-(V_a + 25) / 10)
    axon_beta_x = 1.69 * jnp.exp(-(V_a + 35) / 80)
    axon_tau_x_inv = axon_alpha_x + axon_beta_x
    axon_x_inf = axon_alpha_x / axon_tau_x_inv
    axon_dx_dt = (axon_x_inf - state.axon_Potassium_x) * axon_tau_x_inv

    axon_I_Channels = axon_Ina + axon_Ik
    axon_dv_dt = params.S * (-(axon_I_leak + axon_I_interact + axon_I_Channels))

    # ---------------- dendrite ----------------
    dend_I_leak = params.g_ld * (V_d - params.V_l)
    dend_I_interact = (params.g_int / (1 - params.p1)) * (V_d - V_s)

    # High-threshold Ca (g_CaH): gate r, tau_r = 5 / (alpha + beta).
    # beta_r = 0.02*(V+8.5) / (exp((V+8.5)/5) - 1), rewritten so V = -8.5 is not 0/0.
    dend_Icah = params.g_CaH * state.dend_Calcium_r * state.dend_Calcium_r * (V_d - params.V_Ca)
    dend_alpha_r = 1.7 / (1 + jnp.exp(-(V_d - 5) / 13.9))
    dend_beta_r = 0.02 * 5 * _x_over_expm1((V_d + 8.5) / 5)
    dend_tau_r_inv5 = dend_alpha_r + dend_beta_r
    dend_r_inf = dend_alpha_r / dend_tau_r_inv5
    dend_dr_dt = (dend_r_inf - state.dend_Calcium_r) * dend_tau_r_inv5 * 0.2

    # Ca-activated K (g_K_Ca): alpha_s grows linearly with Ca and saturates at 0.01.
    dend_Ikca = params.g_K_Ca * state.dend_Potassium_s * (V_d - params.V_K)
    dend_alpha_s = jnp.minimum(0.00002 * state.dend_Ca2Plus, 0.01)
    dend_tau_s_inv = dend_alpha_s + 0.015
    dend_s_inf = dend_alpha_s / dend_tau_s_inv
    dend_ds_dt = (dend_s_inf - state.dend_Potassium_s) * dend_tau_s_inv

    # h current (g_h): gate q.
    dend_Ih = params.g_h * state.dend_Hcurrent_q * (V_d - params.V_h)
    q_inf = 1 / (1 + jnp.exp((V_d + 80) / 4))
    tau_q_inv = jnp.exp(-0.086 * V_d - 14.6) + jnp.exp(0.070 * V_d - 1.87)
    dend_dq_dt = (q_inf - state.dend_Hcurrent_q) * tau_q_inv

    # Ca variable: driven by the high-threshold Ca current, linear decay.
    dend_dCa_dt = -3 * dend_Icah - 0.075 * state.dend_Ca2Plus

    dend_I_Channels = dend_Icah + dend_Ikca + dend_Ih
    dend_dv_dt = params.S * (-(dend_I_leak + dend_I_interact - I_app + dend_I_Channels))

    def advance(value, rate, target):
        # Gating variables: Euler step normally, jump to the asymptote in steady-state mode.
        # steady_state is static under jit, so this is a trace-time Python branch.
        return target if steady_state else value + rate * delta

    return IOCellState(
        V_soma=V_s + soma_dv_dt * delta,
        soma_k=advance(state.soma_k, soma_dk_dt, soma_k_inf),
        soma_l=advance(state.soma_l, soma_dl_dt, soma_l_inf),
        soma_h=advance(state.soma_h, soma_dh_dt, soma_h_inf),
        soma_n=advance(state.soma_n, soma_dn_dt, soma_n_inf),
        soma_x=advance(state.soma_x, soma_dx_dt, soma_x_inf),
        V_axon=V_a + axon_dv_dt * delta,
        axon_Sodium_h=advance(state.axon_Sodium_h, axon_dh_dt, axon_h_inf),
        axon_Potassium_x=advance(state.axon_Potassium_x, axon_dx_dt, axon_x_inf),
        V_dend=V_d + dend_dv_dt * delta,
        dend_Ca2Plus=state.dend_Ca2Plus + dend_dCa_dt * delta,
        dend_Calcium_r=advance(state.dend_Calcium_r, dend_dr_dt, dend_r_inf),
        dend_Potassium_s=advance(state.dend_Potassium_s, dend_ds_dt, dend_s_inf),
        dend_Hcurrent_q=advance(state.dend_Hcurrent_q, dend_dq_dt, q_inf),
    )
# Run the demo simulation and plot only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment