Code to reproduce neuronal network simulations in Burman et al. (2023).
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Simulate a LIF neuronal network with variable EGABA and structured inputs.
Copyright (C) 2023 by Paul Brodersen.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or (at
your option) any later version. This program is distributed in the
hope that it will be useful, but WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the GNU General Public License for more details. You
should have received a copy of the GNU General Public License along
with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import time
import pathlib
import brian2 as b2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from uuid import uuid4
from argparse import ArgumentParser
from functools import partial
from scipy.sparse import coo_matrix
from umap import UMAP
from scipy.stats import (
    gaussian_kde,
    entropy as get_entropy,
)
from sklearn.neighbors import KNeighborsClassifier as Classifier
from sklearn.model_selection import (
    cross_val_score,
    StratifiedKFold,
)
def get_peristimulus_histogram(spike_times, interval, bin_width, kde=False):
    if kde:
        evaluate_at = np.arange(interval[0] + bin_width/2, interval[1], bin_width)
        kernel = gaussian_kde(spike_times, bin_width)
        density = kernel(evaluate_at)
        return density / density.sum() * len(spike_times)
    else:
        bins = np.arange(interval[0], interval[1] + bin_width, bin_width)
        indices = np.digitize(spike_times, bins)
        indices -= 1 # before the decrement, an index of zero indicates an out-of-bounds sample
        total_bins = int(round((interval[1] - interval[0]) / bin_width))
        counts = np.bincount(indices[indices >= 0], minlength=total_bins)
        counts = counts[:total_bins] # indices >= total_bins indicate out-of-bounds samples
        return counts.astype(float)
def get_peristimulus_histogram_entropy(spike_times, interval, bin_width):
    histogram = get_peristimulus_histogram(spike_times, interval, bin_width)
    probability = histogram / np.sum(histogram)
    return get_entropy(probability)
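# The entropy of the binned peri-stimulus histogram quantifies the temporal
# structure of a neuron's stimulus response: a neuron that spikes in the same
# narrow window after every stimulus presentation yields a low entropy, whereas
# a neuron that spikes uniformly throughout the interval yields a high entropy.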
###########################################################
parser = ArgumentParser()
parser.add_argument(
    '--stimuli',
    help="Number of different stimuli",
    default=100,
    type=int,
)
parser.add_argument(
    '--repeats',
    help="Number of times each stimulus is repeated",
    default=100,
    type=int,
)
parser.add_argument(
    '-o', '--output',
    help="/path/to/output.csv",
    default=None,
)
parser.add_argument(
    '-f', '--figure_directory',
    help="/path/to/figure/directory",
    default=None,
)
parser.add_argument(
    '--show',
    help="If specified, show plots.",
    action="store_true",
)
args = parser.parse_args()
# create dataset identifier
uuid = uuid4()
###########################################################
# DEFINE EXTERNAL INPUT
b2.start_scope()
total_conditions = 2
total_neurons = 1000
total_stimuli = args.stimuli
total_repeats = total_conditions * args.repeats
total_condition_epochs = total_stimuli * args.repeats
total_balancing_epochs = total_stimuli
total_testing_epochs = total_stimuli * total_repeats
total_epochs = total_balancing_epochs + total_testing_epochs
epoch_duration = 25
balancing_time = total_balancing_epochs * epoch_duration
testing_time = total_testing_epochs * epoch_duration
total_time = balancing_time + testing_time
condition_duration = testing_time / total_conditions
stimulus_set_duration = total_stimuli * epoch_duration
stimulus_duration = 1
stimulus_magnitude = 500
total_subepochs = int(epoch_duration / stimulus_duration)
base_stimuli = np.zeros((total_stimuli * total_subepochs, total_neurons))
base_stimuli[::total_subepochs] = stimulus_magnitude * np.random.lognormal(size=(total_stimuli, total_neurons))
stimuli = np.vstack([base_stimuli] * (total_repeats + 1))
I_stimulus = b2.TimedArray(-stimuli*b2.pA, dt=stimulus_duration * b2.ms)
base_stimuli_labels = np.arange(total_stimuli)
stimuli_labels = np.hstack([base_stimuli_labels] * (total_repeats + 1))
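# In other words, each 25 ms epoch starts with a single 1 ms current pulse
# whose per-neuron amplitudes are drawn once from a lognormal distribution and
# uniquely identify the stimulus; the full stimulus set is presented once for
# synaptic balancing and total_repeats (= 2 x args.repeats, i.e. args.repeats
# per EGABA condition) more times for testing. The currents are negative as
# inward currents are negative by the sign convention of the membrane
# equations below.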
# random background noise
noise_time_scale = 10.
noise_magnitude = 25
arr_noise = noise_magnitude * np.random.lognormal(size=(int(total_time / noise_time_scale), total_neurons))
I_random = b2.TimedArray(-arr_noise*b2.pA, dt=noise_time_scale*b2.ms)
###########################################################
# NETWORK PARAMETERS
params = dict()
params['exc'] = {'name' :'excitatory', 'count':int(4* total_neurons / 5), 'color':'crimson'}
params['inh'] = {'name' :'inhibitory', 'count':int( total_neurons / 5), 'color':'cornflowerblue'}
egaba = [-60, -80]
values = [np.mean(egaba)] + args.repeats * [egaba[0]] + args.repeats * [egaba[1]]
E_GABA_exc = b2.TimedArray(np.array(values) * b2.mV, dt=stimulus_set_duration * b2.ms) # Inhibitory reversal potential (excitatory)
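# The inhibitory reversal potential of the excitatory population steps through
# one value per stimulus set: the mean of the two EGABA values (-70 mV) during
# the initial balancing pass, then args.repeats stimulus sets at -60 mV,
# followed by args.repeats stimulus sets at -80 mV.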
E_L_exc = -60. * b2.mV # Resting membrane potential / reversal of the leak current (excitatory)
V_th_exc = -50. * b2.mV # Spiking threshold (excitatory)
E_GABA_inh = -60. * b2.mV # Inhibitory reversal potential (inhibitory)
E_L_inh = -60. * b2.mV # Resting membrane potential / reversal of the leak current (inhibitory)
V_th_inh = -50. * b2.mV # Spiking threshold (inhibitory)
C_m = 200. * b2.pfarad # Membrane capacitance
g_L = 10. * b2.nsiemens # Leak conductance
E_GLUT = 0. * b2.mV # Excitatory reversal potential
tau_GLUT = 5. * b2.ms # Glutamatergic synaptic time constant
tau_GABA = 10. * b2.ms # GABAergic synaptic time constant
t_ref = 5. * b2.ms # Refractory period
exc_neuron_eqs = '''
    dv/dt = -(I_L + I_GLUT + I_GABA + I_ext)/C_m : volt (unless refractory)
    I_ext = I_stimulus(t, i) + I_random(t, i) : amp
    I_L = g_L * (v - E_L_exc) : amp
    I_GLUT = g_GLUT * (v - E_GLUT) : amp
    I_GABA = g_GABA * (v - E_GABA_exc(t)) : amp
    dg_GLUT/dt = -g_GLUT / tau_GLUT : siemens
    dg_GABA/dt = -g_GABA / tau_GABA : siemens
'''
inh_neuron_eqs = '''
    dv/dt = -(I_L + I_GLUT + I_GABA + I_ext)/C_m : volt (unless refractory)
    I_ext = I_stimulus(t, i) + I_random(t, i) : amp
    I_L = g_L * (v - E_L_inh) : amp
    I_GLUT = g_GLUT * (v - E_GLUT) : amp
    I_GABA = g_GABA * (v - E_GABA_inh) : amp
    dg_GLUT/dt = -g_GLUT / tau_GLUT : siemens
    dg_GABA/dt = -g_GABA / tau_GABA : siemens
'''
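# Both populations are conductance-based leaky integrate-and-fire neurons; the
# equations only differ in the GABAergic driving force, which is time-dependent
# for the excitatory population (E_GABA_exc(t)) and fixed for the inhibitory
# population (E_GABA_inh).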
net = b2.Network(b2.collect())
population = dict()
population['exc'] = b2.NeuronGroup(params['exc']['count'],
                                   model = exc_neuron_eqs,
                                   threshold = 'v>V_th_exc',
                                   reset = 'v=E_L_exc',
                                   refractory = t_ref,
                                   method = 'euler')
population['inh'] = b2.NeuronGroup(params['inh']['count'],
                                   model = inh_neuron_eqs,
                                   threshold = 'v>V_th_inh',
                                   reset = 'v=E_L_inh',
                                   refractory = t_ref,
                                   method = 'euler')
population['exc'].v = E_L_exc
population['inh'].v = E_L_inh
net.add(population)
###########################################################
# CONNECTIVITY
weight_ex = 0.1 * b2.nsiemens # Excitatory weight
weight_in = 1. * b2.nsiemens # Inhibitory weight
gmax = 100. * b2.nsiemens # Maximum inhibitory weight
tau_stdp = 20. * b2.ms # STDP time constant
rho = 3. * b2.Hz # Target excitatory population rate
beta = rho * tau_stdp * 2 # Target rate parameter
p_e = 0.1 # excitatory connection probability
p_i = 0.1 # inhibitory connection probability
static_ex_model = dict(model='w : siemens', on_pre='g_GLUT_post += w')
static_in_model = dict(model='w : siemens', on_pre='g_GABA_post += w')
vogels_model = dict(
    model='''
    w : siemens
    dApre/dt = -Apre / tau_stdp : siemens (event-driven)
    dApost/dt = -Apost / tau_stdp : siemens (event-driven)
    ''',
    on_pre='''
    Apre += 1.*nsiemens
    w = clip(w + (Apost - beta * nsiemens) * eta, 0 * nsiemens, gmax)
    g_GABA_post += w''',
    on_post='''
    Apost += 1. * nsiemens
    w = clip(w + Apre * eta, 0 * nsiemens, gmax)
    ''')
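# Inhibitory synapses onto excitatory neurons are plastic and follow the
# symmetric STDP rule of Vogels et al. (2011): near-coincident pre- and
# postsynaptic spikes potentiate the synapse, while each presynaptic spike
# also depresses it by a fixed amount (beta), driving the excitatory
# population towards the target rate rho. The learning rate eta is set in the
# run loop below.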
static_ex_synapse = partial(b2.Synapses, **static_ex_model)
static_in_synapse = partial(b2.Synapses, **static_in_model)
vogels_synapse = partial(b2.Synapses, **vogels_model)
conn_params = dict()
conn_params[('exc','exc')] = dict(model=static_ex_synapse, p=p_e, w=weight_ex)
conn_params[('exc','inh')] = dict(model=static_ex_synapse, p=p_e, w=weight_ex)
conn_params[('inh','inh')] = dict(model=static_in_synapse, p=p_i, w=weight_in)
conn_params[('inh','exc')] = dict(model=vogels_synapse, p=p_i, w=weight_in)
connectivity = dict()
for connection in conn_params:
    connectivity[connection] = conn_params[connection]['model'](population[connection[0]], population[connection[1]])
    connectivity[connection].connect(p=conn_params[connection]['p'])
    connectivity[connection].w = conn_params[connection]['w']
net.add(connectivity)
# ###########################################
# MONITORS
spike_monitor = dict()
rate_monitor = dict()
for label in params.keys():
    spike_monitor[label] = b2.SpikeMonitor(population[label])
    rate_monitor[label] = b2.PopulationRateMonitor(population[label])
net.add(spike_monitor)
net.add(rate_monitor)
# ###########################################
# RUN SIMULATION
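# The simulation runs in two phases: a balancing phase in which inhibitory
# plasticity is active (eta = 0.1), and a testing phase in which the learned
# inhibitory weights are frozen (eta = 0).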
for simulation_time, eta in zip((balancing_time, testing_time), [0.1, 0.]):
    tic = time.time()
    net.run(simulation_time * b2.ms)
    toc = time.time()
    print(f"Time elapsed: {toc-tic:.2f} seconds.")
# ###########################################
# ANALYSIS & PLOTS
for ctr, E_GABA in enumerate(egaba):
    condition_start = balancing_time + ctr * condition_duration
    condition_stop = condition_start + condition_duration

    if args.figure_directory or args.show:
        # plot example input currents and neuronal spike times
        fig, (ax0a, ax0b, ax1, ax2, ax3) = plt.subplots(5, 1, sharex=True, figsize=(6.85, 12.5))
        time_points = rate_monitor['exc'].t
        ax0a.plot(time_points / b2.ms, np.array([I_stimulus(t, 0) / b2.pA for t in time_points]), label='Example A')
        ax0b.plot(time_points / b2.ms, np.array([I_stimulus(t, 1) / b2.pA for t in time_points]), label='Example B')
        ax0a.legend(loc='upper right')
        ax0b.legend(loc='upper right')
        ax0a.set_ylabel("Input currents [pA]")
        ax0b.set_ylabel("Input currents [pA]")
        for ax, key, label in zip((ax1, ax2), params.keys(), ('Excitatory neurons', 'Inhibitory neurons')):
            ax.plot(spike_monitor[key].t/b2.ms, spike_monitor[key].i, '.',
                    markersize = 2,
                    color = params[key]['color'],
                    rasterized = True)
            ax.set_ylabel(label)
        # plot population firing rates
        for label in params.keys():
            ax3.plot(rate_monitor[label].t/b2.ms,
                     rate_monitor[label].smooth_rate(window = 'gaussian', width = 10 * b2.ms)/b2.Hz,
                     label = params[label]['name'],
                     color = params[label]['color'])
        ax3.set_ylabel('Firing rate [Hz]')
        start = condition_start
        stop = condition_start + 5 * epoch_duration
        ax3.set_xlim(start, stop)
        ax3.legend(loc='upper right')
        ax3.set_xlabel('Time [ms]')
        fig.tight_layout()
        fig.align_ylabels()
        if args.figure_directory:
            fig.savefig(args.figure_directory + f"/{uuid}_E_GABA_{E_GABA:.1f}_example_responses.pdf")
    spike_times = spike_monitor['exc'].t / b2.ms
    spike_indices = spike_monitor['exc'].i
    if not len(spike_times):
        raise ValueError("Simulation yielded no spikes during testing.")

    # remove spikes outside of the condition of interest
    valid = (spike_times >= condition_start) & (spike_times < condition_stop)
    spike_indices = spike_indices[valid]
    spike_times = spike_times[valid]
    spike_times -= condition_start

    # compute stimulus discriminability
    total_neurons = params['exc']['count']
    epoch_indices = (spike_times / epoch_duration).astype(int)
    # convert to CSR format, as COO matrices do not support slicing
    X = coo_matrix((np.ones_like(spike_times), (epoch_indices, spike_indices)), shape=(total_condition_epochs, total_neurons)).tocsr()
    y = stimuli_labels[(total_balancing_epochs + ctr * total_condition_epochs):
                       (total_balancing_epochs + (ctr + 1) * total_condition_epochs)]
    accuracy = np.mean(cross_val_score(Classifier(), X[:, :50], y, cv=StratifiedKFold(n_splits=5, shuffle=True)))
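    # Discriminability is the mean 5-fold cross-validated accuracy of a
    # k-nearest neighbours classifier predicting the stimulus identity from the
    # per-epoch spike counts of the first 50 excitatory neurons.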
    if args.figure_directory or args.show:
        fig, ax = plt.subplots()
        xx, yy = UMAP().fit_transform(X[:, :50]).transpose()
        # apply small amounts of noise so that samples within a class overlap less
        xx += 0.005 * np.ptp(xx) * np.random.randn(*xx.shape)
        yy += 0.005 * np.ptp(yy) * np.random.randn(*yy.shape)
        ax.scatter(xx, yy, c=y, cmap=plt.get_cmap('jet', 100), s=0.1, alpha=0.8)
        ax.set_xlabel('UMAP embedding\ndimension 1')
        ax.set_ylabel('UMAP embedding\ndimension 2')
        fig.tight_layout()
        if args.figure_directory:
            fig.savefig(args.figure_directory + f"/{uuid}_E_GABA_{E_GABA:.1f}_umap.pdf")
    # computing the neuron properties takes a very long time;
    # remove spikes outside of the first stimulus set
    is_valid = (spike_times >= 0) & (spike_times < stimulus_set_duration)
    spike_indices = spike_indices[is_valid]
    spike_times = spike_times[is_valid]
    epoch_indices = epoch_indices[is_valid]

    # retain only peri-stimulus spikes
    spike_times = spike_times % epoch_duration
    interval = (0, 12)
    is_valid = np.logical_and(spike_times >= interval[0], spike_times < interval[1])
    spike_times = spike_times[is_valid]
    spike_indices = spike_indices[is_valid]
    epoch_indices = epoch_indices[is_valid]

    # compute peri-stimulus spike counts & firing rates
    unique_neurons = np.unique(spike_indices)
    spike_counts = np.zeros(total_neurons)
    for neuron in unique_neurons:
        spike_counts[neuron] = (spike_indices == neuron).sum()
    total_peristimulus_time = total_stimuli * (interval[1] - interval[0])
    firing_rates = spike_counts / (total_peristimulus_time / 1000)
    # plot peri-stimulus histogram
    if args.figure_directory or args.show:
        dt = 0.1
        time_points = np.arange(*interval, dt)
        response = get_peristimulus_histogram(spike_times, interval, dt, kde=True)
        response = response / np.sum(response)
        fig, ax = plt.subplots()
        ax.plot(time_points, response)
        ax.set_xlabel("Time [ms]")
        ax.set_ylabel("Population response")
        primary_response = np.sum(response[time_points <= 2])
        secondary_response = np.sum(response[time_points > 2])
        ax.set_title(f"Primary : secondary = {primary_response:.2f} : {secondary_response:.2f}")
        ax.grid(True)
        fig.tight_layout()
        if args.figure_directory:
            fig.savefig(args.figure_directory + f"/{uuid}_E_GABA_{E_GABA:.1f}_peristimulus_histogram.pdf")
    # compute peri-stimulus histogram entropy
    entropy = np.full(total_neurons, np.nan)
    for neuron in unique_neurons:
        mask = spike_indices == neuron
        times = spike_times[mask]
        entropy[neuron] = get_peristimulus_histogram_entropy(times, interval, 1)
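    # scipy.stats.entropy uses the natural logarithm by default,
    # so the entropy values are reported in nats.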
    # compute synchrony
    synchrony = dict()
    for epoch in range(total_stimuli):
        population_spikes = spike_times[epoch_indices == epoch]
        if len(population_spikes) > 0:
            population_response = get_peristimulus_histogram(population_spikes, interval, 1)
            population_response /= population_response.sum()
            for neuron in unique_neurons:
                neuron_spikes = spike_times[np.logical_and(epoch_indices == epoch, spike_indices == neuron)]
                if len(neuron_spikes) > 0:
                    indices = (neuron_spikes - interval[0]).astype(int)
                    s = population_response[indices].sum()
                    n = len(indices)
                    if neuron in synchrony:
                        synchrony[neuron][0] += s
                        synchrony[neuron][1] += n
                    else:
                        synchrony[neuron] = [s, n]
    synchrony = np.array([synchrony[neuron][0] / synchrony[neuron][1]
                          if neuron in synchrony else np.nan
                          for neuron in range(total_neurons)])
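    # Synchrony is thus the average (normalised) population response in the
    # 1 ms bins in which a neuron itself spikes: values are high for neurons
    # that preferentially fire when the rest of the population fires, and low
    # for neurons that fire when the population is quiet.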
    if args.output:
        data = dict()
        data['uid'] = uuid
        data['unit'] = np.arange(total_neurons, dtype=int)
        data[r'E$_{GABA}$ [mV]'] = E_GABA
        data[r'E$_{L}$ [mV]'] = -60.
        data[r'V$_{th}$ [mV]'] = -50.
        data['Spike count'] = spike_counts
        data['Firing rate [Hz]'] = firing_rates
        data['Entropy [nats]'] = entropy
        data['Synchrony'] = synchrony
        data["Discriminability [%]"] = 100 * accuracy

        output_path = pathlib.Path(args.output)
        if output_path.exists():
            old = pd.read_csv(output_path)
            new = pd.DataFrame(data)
            df = pd.concat([old, new], ignore_index=True)
        else:
            df = pd.DataFrame(data)
        df.to_csv(output_path, index=False)

        print()
        print("Results")
        print("=======")
        try:
            print(new.mean(axis=0, numeric_only=True))
        except NameError: # `new` only exists if the output file already existed
            print(df.mean(axis=0, numeric_only=True))
        print()
    if args.show:
        plt.ion()
        plt.show()
        input("Press enter to continue...")
        plt.close("all")