Last active
January 30, 2025 22:32
-
-
Save llandsmeer/c56aba4a5a4d0aa249afded929a7d1b0 to your computer and use it in GitHub Desktop.
iojax
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax | |
import jax.numpy as jnp | |
import matplotlib.pyplot as plt | |
import functools | |
import typing | |
def main():
    """Simulate a single cell with zero applied current and plot V_soma.

    Builds parameters with ``g_CaL`` overridden to 1.2, derives the initial
    state with ``IOCellState.make``, integrates with ``jax.lax.scan`` and
    shows the recorded somatic voltage with matplotlib.
    """
    cell_params = IOCellParams(g_CaL=1.2)
    initial_state = IOCellState.make(cell_params)

    @jax.jit
    def step(carry, current):
        # scan body: the new state becomes the carry, while the state *before*
        # the step is what gets stacked into the recorded trace.
        advanced = timestep(carry, cell_params, I_app=current)
        return advanced, carry

    # 40000 timesteps at 0.025 ms timesteps is 1 second, so this is 10 of them.
    n_steps = 10 * 40_000
    iapp = jnp.zeros(n_steps)
    # or iapp = jnp.linspace(0, -0.1, 10 * 40_000)

    # The final carry is not needed; only the per-step trace is plotted.
    _, trace = jax.lax.scan(step, initial_state, iapp)
    plt.plot(trace.V_soma)
    plt.show()
class IOCellParams(typing.NamedTuple):
    """Constant parameters of the three-compartment cell used by ``timestep``.

    Every field has a default, so ``IOCellParams()`` is a complete parameter
    set and individual values can be overridden by keyword.  Numbers in
    parentheses in the field comments are carried over from the original
    annotations of this file.
    """

    # --- coupling between compartments ---
    g_int: float = 0.13      # Cell internal conductance -- now a parameter (0.13)
    p1: float = 0.25         # Cell surface ratio soma/dendrite
    p2: float = 0.15         # Cell surface ratio axon(hillock)/soma

    # --- channel and leak conductances ---
    g_CaL: float = 1.1       # Calcium T - (CaV 3.1) (0.7); scales the somatic calcium current
    g_h: float = 0.12        # H current (HCN) (0.4996); scales the dendritic H current
    g_K_Ca: float = 35.0     # Potassium (KCa v1.1 - BK) (35); dendritic
    g_ld: float = 0.01532    # Leak dendrite (0.016)
    g_la: float = 0.016      # Leak axon (0.016)
    g_ls: float = 0.016      # Leak soma (0.016)
    g_Na_s: float = 150.0    # Sodium - (Na v1.6), soma
    g_Kdr_s: float = 9.0     # Potassium - (K v4.3), soma
    g_K_s: float = 5.0       # Potassium - (K v3.4), soma
    g_CaH: float = 4.5       # High-threshold calcium -- Ca V2.1, dendrite
    g_Na_a: float = 240.0    # Sodium, axon
    g_K_a: float = 240.0     # Potassium (20), axon

    # --- membrane ---
    S: float = 1.0           # 1/C_m, cm^2/uF

    # --- reversal potentials ---
    V_Na: float = 55.0       # Sodium
    V_K: float = -75.0       # Potassium
    V_Ca: float = 120.0      # Calcium (used by both calcium currents)
    V_h: float = -43.0       # H current
    V_l: float = 10.0        # Leak
class IOCellState(typing.NamedTuple):
    """Dynamic state of one cell: three voltages, gating variables and calcium.

    Being a NamedTuple, an instance can be passed through ``jax.jit`` and used
    as the carry of ``jax.lax.scan``; each field is either a Python float or a
    JAX array.
    """

    # --- soma ---
    V_soma: float | jax.Array = -60.0
    soma_k: float | jax.Array = 0.0            # gate of the somatic calcium current (cubed)
    soma_l: float | jax.Array = 0.0            # second gate of the somatic calcium current
    soma_h: float | jax.Array = 0.0            # gate of the somatic sodium current
    soma_n: float | jax.Array = 0.0            # gate of the g_Kdr_s potassium current
    soma_x: float | jax.Array = 0.1            # gate of the g_K_s potassium current

    # --- axon ---
    V_axon: float | jax.Array = -60.0
    axon_Sodium_h: float | jax.Array = 0.0     # gate of the axonal sodium current
    axon_Potassium_x: float | jax.Array = 0.0  # gate of the axonal potassium current

    # --- dendrite ---
    V_dend: float | jax.Array = -60.0
    dend_Ca2Plus: float | jax.Array = 3.715    # calcium variable driving the g_K_Ca gate
    dend_Calcium_r: float | jax.Array = 0.0    # gate of the g_CaH calcium current
    dend_Potassium_s: float | jax.Array = 0.0  # gate of the g_K_Ca potassium current
    dend_Hcurrent_q: float | jax.Array = 0.0   # gate of the H current

    @classmethod
    def make(cls, params: IOCellParams, V_soma=-49.6, V_axon=-49.6, V_dend=-49.6):
        """Build an initial state for the given compartment voltages.

        A state with the requested voltages (and default values elsewhere) is
        passed once through ``timestep(..., steady_state=True)`` and the
        result of that call is returned as-is, so the voltage and calcium
        fields are whatever that call produces.
        """
        seed = cls(V_soma=V_soma, V_axon=V_axon, V_dend=V_dend)
        return timestep(seed, params=params, steady_state=True)
def _ratio_or_limit(numerator, denominator, limit):
    """Return ``numerator / denominator``, or ``limit`` where the denominator is 0.

    Used for rate functions of the form ``c*u / (1 - exp(-u/k))`` whose 0/0
    point would otherwise evaluate to NaN (or inf once the exponential rounds
    to exactly 1).  Everywhere else the plain quotient is returned unchanged.
    """
    at_zero = denominator == 0
    # Replace the denominator before dividing so the discarded branch never
    # computes x/0, which would leak NaN into gradients taken through where().
    safe_denominator = jnp.where(at_zero, 1.0, denominator)
    return jnp.where(at_zero, limit, numerator / safe_denominator)


@functools.partial(jax.jit, static_argnames=['steady_state'])
def timestep(state: IOCellState, params: IOCellParams, delta=0.025, I_app=0.0, steady_state=False):
    """Advance the cell by one explicit Euler step and return the new state.

    Args:
        state: current voltages, gating variables and calcium.
        params: conductances, reversal potentials and coupling constants.
        delta: step size multiplied onto every derivative (``main`` in this
            file treats the default 0.025 as milliseconds).
        I_app: applied current, added to the dendritic compartment only.
        steady_state: static (compile-time) flag.  When true, every gating
            variable is set to its steady-state value at the current voltages
            instead of being integrated.  The three voltages and
            ``dend_Ca2Plus`` still take one Euler step in this mode.

    Returns:
        A new ``IOCellState``.
    """
    V_soma, V_axon, V_dend = state.V_soma, state.V_axon, state.V_dend

    def gate(current, rate, steady_value):
        # steady_state is a plain Python bool (static under jit), so this
        # branch is resolved at trace time.
        return steady_value if steady_state else current + rate * delta

    # ------------------------------ soma ------------------------------
    soma_I_leak = params.g_ls * (V_soma - params.V_l)
    I_ds = (params.g_int / params.p1) * (V_soma - V_dend)
    I_as = (params.g_int / (1 - params.p2)) * (V_soma - V_axon)
    soma_I_interact = I_ds + I_as

    # Calcium current, gates k (cubed) and l.
    soma_Ical = params.g_CaL * state.soma_k * state.soma_k * state.soma_k * state.soma_l * (V_soma - params.V_Ca)
    soma_k_inf = 1 / (1 + jnp.exp(-(V_soma + 61) / 4.2))
    soma_l_inf = 1 / (1 + jnp.exp((V_soma + 85) / 8.5))
    soma_tau_l = (20 * jnp.exp((V_soma + 160) / 30) / (1 + jnp.exp((V_soma + 84) / 7.3))) + 35
    soma_dk_dt = soma_k_inf - state.soma_k  # no time constant is applied to k
    soma_dl_dt = (soma_l_inf - state.soma_l) / soma_tau_l

    # Sodium current: activation is taken at its instantaneous value m_inf,
    # only the h gate is a state variable.
    soma_m_inf = 1 / (1 + jnp.exp(-(V_soma + 30) / 5.5))
    soma_h_inf = 1 / (1 + jnp.exp((V_soma + 70) / 5.8))
    soma_Ina = params.g_Na_s * soma_m_inf**3 * state.soma_h * (V_soma - params.V_Na)
    soma_tau_h = 3 * jnp.exp(-(V_soma + 40) / 33)
    soma_dh_dt = (soma_h_inf - state.soma_h) / soma_tau_h

    # Potassium current scaled by g_Kdr_s, gate n.
    soma_Ikdr = params.g_Kdr_s * state.soma_n**4 * (V_soma - params.V_K)
    soma_n_inf = 1 / (1 + jnp.exp(-(V_soma + 3) / 10))
    soma_tau_n = 5 + (47 * jnp.exp((V_soma + 50) / 900))
    soma_dn_dt = (soma_n_inf - state.soma_n) / soma_tau_n

    # Potassium current scaled by g_K_s, gate x.
    soma_Ik = params.g_K_s * state.soma_x**4 * (V_soma - params.V_K)
    # alpha_x is 0/0 at V_soma == -25; its limit there is 0.13 * 10.
    soma_alpha_x = _ratio_or_limit(
        0.13 * (V_soma + 25), 1 - jnp.exp(-(V_soma + 25) / 10), 1.3)
    soma_beta_x = 1.69 * jnp.exp(-(V_soma + 35) / 80)
    soma_tau_x_inv = soma_alpha_x + soma_beta_x
    soma_x_inf = soma_alpha_x / soma_tau_x_inv
    soma_dx_dt = (soma_x_inf - state.soma_x) * soma_tau_x_inv

    soma_I_Channels = soma_Ik + soma_Ikdr + soma_Ina + soma_Ical
    soma_dv_dt = params.S * (-(soma_I_leak + soma_I_interact + soma_I_Channels))

    # ------------------------------ axon ------------------------------
    axon_I_leak = params.g_la * (V_axon - params.V_l)
    I_sa = (params.g_int / params.p2) * (V_axon - V_soma)
    axon_I_interact = I_sa

    # Sodium current, instantaneous activation and gate h.
    axon_m_inf = 1 / (1 + jnp.exp(-(V_axon + 30) / 5.5))
    axon_h_inf = 1 / (1 + jnp.exp((V_axon + 60) / 5.8))
    axon_Ina = params.g_Na_a * axon_m_inf**3 * state.axon_Sodium_h * (V_axon - params.V_Na)
    axon_tau_h = 1.5 * jnp.exp(-(V_axon + 40) / 33)
    axon_dh_dt = (axon_h_inf - state.axon_Sodium_h) / axon_tau_h

    # Potassium current, gate x (same rate functions as the somatic x gate).
    axon_Ik = params.g_K_a * state.axon_Potassium_x**4 * (V_axon - params.V_K)
    # Same removable singularity as in the soma, here at V_axon == -25.
    axon_alpha_x = _ratio_or_limit(
        0.13 * (V_axon + 25), 1 - jnp.exp(-(V_axon + 25) / 10), 1.3)
    axon_beta_x = 1.69 * jnp.exp(-(V_axon + 35) / 80)
    axon_tau_x_inv = axon_alpha_x + axon_beta_x
    axon_x_inf = axon_alpha_x / axon_tau_x_inv
    axon_dx_dt = (axon_x_inf - state.axon_Potassium_x) * axon_tau_x_inv

    axon_I_Channels = axon_Ina + axon_Ik
    axon_dv_dt = params.S * (-(axon_I_leak + axon_I_interact + axon_I_Channels))

    # ---------------------------- dendrite ----------------------------
    dend_I_leak = params.g_ld * (V_dend - params.V_l)
    dend_I_interact = (params.g_int / (1 - params.p1)) * (V_dend - V_soma)

    # Calcium current scaled by g_CaH, gate r (squared).
    dend_Icah = params.g_CaH * state.dend_Calcium_r * state.dend_Calcium_r * (V_dend - params.V_Ca)
    dend_alpha_r = 1.7 / (1 + jnp.exp(-(V_dend - 5) / 13.9))
    # beta_r is 0/0 at V_dend == -8.5; its limit there is 0.02 * 5.
    dend_beta_r = _ratio_or_limit(
        0.02 * (V_dend + 8.5), jnp.exp((V_dend + 8.5) / 5) - 1.0, 0.1)
    dend_tau_r_inv5 = (dend_alpha_r + dend_beta_r)  # tau = 5 / (alpha + beta)
    dend_r_inf = dend_alpha_r / dend_tau_r_inv5
    dend_dr_dt = (dend_r_inf - state.dend_Calcium_r) * dend_tau_r_inv5 * 0.2

    # Potassium current scaled by g_K_Ca, gate s driven by the calcium variable.
    dend_Ikca = params.g_K_Ca * state.dend_Potassium_s * (V_dend - params.V_K)
    # alpha_s grows linearly with calcium and is capped at 0.01.
    dend_alpha_s = jnp.where(0.00002 * state.dend_Ca2Plus < 0.01, 0.00002 * state.dend_Ca2Plus, 0.01)
    dend_tau_s_inv = dend_alpha_s + 0.015
    dend_s_inf = dend_alpha_s / dend_tau_s_inv
    dend_ds_dt = (dend_s_inf - state.dend_Potassium_s) * dend_tau_s_inv

    # H current, gate q.
    dend_Ih = params.g_h * state.dend_Hcurrent_q * (V_dend - params.V_h)
    q_inf = 1 / (1 + jnp.exp((V_dend + 80) / 4))
    tau_q_inv = jnp.exp(-0.086 * V_dend - 14.6) + jnp.exp(0.070 * V_dend - 1.87)
    dend_dq_dt = (q_inf - state.dend_Hcurrent_q) * tau_q_inv

    # Calcium is driven by the g_CaH current and decays linearly.
    dend_dCa_dt = -3 * dend_Icah - 0.075 * state.dend_Ca2Plus

    dend_I_Channels = dend_Icah + dend_Ikca + dend_Ih
    # The applied current enters with the opposite sign to the membrane currents.
    dend_dv_dt = params.S * (-(dend_I_leak + dend_I_interact - I_app + dend_I_Channels))

    return IOCellState(
        V_soma=V_soma + soma_dv_dt * delta,
        soma_k=gate(state.soma_k, soma_dk_dt, soma_k_inf),
        soma_l=gate(state.soma_l, soma_dl_dt, soma_l_inf),
        soma_h=gate(state.soma_h, soma_dh_dt, soma_h_inf),
        soma_n=gate(state.soma_n, soma_dn_dt, soma_n_inf),
        soma_x=gate(state.soma_x, soma_dx_dt, soma_x_inf),
        V_axon=V_axon + axon_dv_dt * delta,
        axon_Sodium_h=gate(state.axon_Sodium_h, axon_dh_dt, axon_h_inf),
        axon_Potassium_x=gate(state.axon_Potassium_x, axon_dx_dt, axon_x_inf),
        V_dend=V_dend + dend_dv_dt * delta,
        dend_Ca2Plus=state.dend_Ca2Plus + dend_dCa_dt * delta,
        dend_Calcium_r=gate(state.dend_Calcium_r, dend_dr_dt, dend_r_inf),
        dend_Potassium_s=gate(state.dend_Potassium_s, dend_ds_dt, dend_s_inf),
        dend_Hcurrent_q=gate(state.dend_Hcurrent_q, dend_dq_dt, q_inf),
    )
# Run the demo simulation only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment