Last active
June 12, 2023 15:48
-
-
Save paulbrodersen/2ce61e4e57c7659b727d30e3e1e0128f to your computer and use it in GitHub Desktop.
Code to reproduce neuronal network simulations in Burman et al. (2023).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Simulate a LIF neuronal network with variable EGABA and structured inputs. | |
Copyright (C) 2023 by Paul Brodersen. | |
This program is free software: you can redistribute it and/or modify | |
it under the terms of the GNU General Public License as published by | |
the Free Software Foundation, either version 3 of the License, or (at | |
your option) any later version. This program is distributed in the | |
hope that it will be useful, but WITHOUT ANY WARRANTY; without even | |
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR | |
PURPOSE. See the GNU General Public License for more details. You | |
should have received a copy of the GNU General Public License along | |
with this program. If not, see <https://www.gnu.org/licenses/>.
""" | |
import time | |
import pathlib | |
import brian2 as b2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from uuid import uuid4 | |
from argparse import ArgumentParser | |
from functools import partial | |
from scipy.sparse import coo_matrix | |
from umap import UMAP | |
from scipy.stats import ( | |
gaussian_kde, | |
entropy as get_entropy, | |
) | |
from sklearn.neighbors import KNeighborsClassifier as Classifier | |
from sklearn.model_selection import ( | |
cross_val_score, | |
StratifiedKFold, | |
) | |
def get_peristimulus_histogram(spike_times, interval, bin_width, kde=False):
    """Compute a peri-stimulus time histogram (PSTH) of spike times.

    Parameters
    ----------
    spike_times : array_like
        Spike times, in the same units as `interval` and `bin_width`.
    interval : tuple of (start, stop)
        Window over which the histogram is computed; spikes outside
        [start, stop) do not contribute to the counts.
    bin_width : float
        Width of each histogram bin.
    kde : bool, default False
        If True, return a Gaussian kernel density estimate evaluated at the
        bin centres, rescaled so that it sums to the number of spikes,
        instead of raw spike counts.

    Returns
    -------
    numpy.ndarray of float
        One value per bin.
    """
    start, stop = interval
    # number of bins in the interval; rounding guards against floating-point
    # error when the interval is an integer multiple of a non-integer bin width
    total_bins = int(round((stop - start) / bin_width))

    if kde:
        evaluate_at = np.arange(start + bin_width/2, stop, bin_width)
        # NOTE(review): scipy uses a scalar `bw_method` as the kernel *factor*
        # (bandwidth = factor * sample standard deviation), not as an absolute
        # bandwidth in units of time -- confirm this is intended.
        kernel = gaussian_kde(spike_times, bin_width)
        density = kernel(evaluate_at)
        return density / density.sum() * len(spike_times)

    # exact bin edges (avoids the extra edge a floating-point arange can produce)
    edges = start + bin_width * np.arange(total_bins + 1)
    # digitize returns 0 for samples left of the first edge and len(edges) for
    # samples at or beyond the last edge; after shifting by one these become
    # -1 and total_bins, i.e. out-of-range indices that are discarded below
    indices = np.digitize(spike_times, edges) - 1
    in_range = (indices >= 0) & (indices < total_bins)
    counts = np.bincount(indices[in_range], minlength=total_bins)
    return counts.astype(float)
def get_peristimulus_histogram_entropy(spike_times, interval, bin_width):
    """Return the entropy (in nats) of the peri-stimulus histogram of `spike_times`.

    The spike-count histogram over `interval` (bins of width `bin_width`) is
    normalised to a probability distribution before the entropy is taken.
    """
    counts = get_peristimulus_histogram(spike_times, interval, bin_width)
    return get_entropy(counts / counts.sum())
###########################################################
# COMMAND-LINE INTERFACE

# (flags, keyword arguments) for each command-line option, listed in the
# order in which they are registered (and shown by --help).
_cli_options = (
    (('--stimuli',), dict(help="Number of different stimuli", default=100, type=int)),
    (('--repeats',), dict(help="Number of times each stimulus is repeated", default=100, type=int)),
    (('-o', '--output'), dict(help="/path/to/output.csv", default=None)),
    (('-f', '--figure_directory'), dict(help="/path/to/figure/directory", default=None)),
    (('--show',), dict(help="If specified, show plots.", action="store_true")),
)

parser = ArgumentParser()
for _flags, _settings in _cli_options:
    parser.add_argument(*_flags, **_settings)
args = parser.parse_args()

# dataset identifier for this run; tags the output rows and figure file names
uuid = uuid4()
###########################################################
# DEFINE EXTERNAL INPUT

b2.start_scope()

# Run layout (all times in ms): a balancing phase containing one presentation
# of every stimulus, followed by a testing phase in which the whole stimulus
# set is presented `args.repeats` times for each of the `total_conditions`
# conditions.
total_conditions = 2
total_neurons = 1000
total_stimuli = args.stimuli
total_repeats = total_conditions * args.repeats  # stimulus-set repeats summed over all conditions
total_condition_epochs = total_stimuli * args.repeats  # epochs (stimulus presentations) per condition
total_balancing_epochs = total_stimuli
total_testing_epochs = total_stimuli * total_repeats
total_epochs = total_balancing_epochs + total_testing_epochs

epoch_duration = 25  # duration of one stimulus presentation
balancing_time = total_balancing_epochs * epoch_duration
testing_time = total_testing_epochs * epoch_duration
total_time = balancing_time + testing_time
condition_duration = testing_time / total_conditions
stimulus_set_duration = total_stimuli * epoch_duration

# Each epoch is split into `stimulus_duration`-long sub-epochs. Only the first
# sub-epoch of an epoch carries the stimulus (an independent log-normally
# distributed current per stimulus and neuron); the rest are zero. The same
# stimulus set is tiled once for the balancing phase and once per testing repeat.
stimulus_duration = 1
stimulus_magnitude = 500
total_subepochs = int(epoch_duration / stimulus_duration)
base_stimuli = np.zeros((total_stimuli * total_subepochs, total_neurons))
base_stimuli[::total_subepochs] = stimulus_magnitude * np.random.lognormal(size=(total_stimuli, total_neurons))
stimuli = np.vstack([base_stimuli] * (total_repeats + 1))
# negated because the membrane equations use dv/dt = -(... + I_ext)/C_m
I_stimulus = b2.TimedArray(-stimuli*b2.pA, dt=stimulus_duration * b2.ms)
base_stimuli_labels = np.arange(total_stimuli)
stimuli_labels = np.hstack([base_stimuli_labels] * (total_repeats + 1))  # stimulus id of every epoch

# random background noise, redrawn every `noise_time_scale` ms; the number of
# samples is rounded up so that the noise covers the whole run even when
# `total_time` is not a multiple of `noise_time_scale`
noise_time_scale = 10.
noise_magnitude = 25
arr_noise = noise_magnitude * np.random.lognormal(size=(int(np.ceil(total_time / noise_time_scale)), total_neurons))
I_random = b2.TimedArray(-arr_noise*b2.pA, dt=noise_time_scale*b2.ms)
###########################################################
# NETWORK PARAMETERS
# NOTE: every name defined below that appears inside an equation string
# (E_GABA_exc, E_L_exc, C_m, g_L, tau_GLUT, ...) is looked up by Brian2 from
# this module's namespace, so these names must not be renamed.

# population sizes (4:1 split of `total_neurons`) and plotting colours
params = dict()
params['exc'] = {'name' :'excitatory', 'count':int(4* total_neurons / 5), 'color':'crimson'}
params['inh'] = {'name' :'inhibitory', 'count':int( total_neurons / 5), 'color':'cornflowerblue'}

# E_GABA of the excitatory population is time-dependent: the mean of `egaba`
# for the first stimulus set (the balancing phase, whose length equals
# `stimulus_set_duration`), then egaba[0] for `args.repeats` stimulus sets,
# then egaba[1] for another `args.repeats` stimulus sets.
egaba = [-60, -80]
values = [np.mean(egaba)] + args.repeats * [egaba[0]] + args.repeats * [egaba[1]]
E_GABA_exc = b2.TimedArray(np.array(values) * b2.mV, dt=stimulus_set_duration * b2.ms) # Inhibitory reversal potential (excitatory)
E_L_exc = -60. * b2.mV # Resting membrane potential / reversal of the leak current (excitatory)
V_th_exc = -50. * b2.mV # Spiking threshold (excitatory)

E_GABA_inh = -60. * b2.mV # Inhibitory reversal potential (inhibitory)
E_L_inh = -60. * b2.mV # Resting membrane potential / reversal of the leak current (inhibitory)
V_th_inh = -50. * b2.mV # Spiking threshold (inhibitory)

C_m = 200. * b2.pfarad # Membrane capacitance
g_L = 10. * b2.nsiemens # Leak conductance
E_GLUT = 0. * b2.mV # Excitatory reversal potential
tau_GLUT = 5. * b2.ms # Glutamatergic synaptic time constant
tau_GABA = 10. * b2.ms # GABAergic synaptic time constant
t_ref = 5. * b2.ms # Refractory period

# Leaky integrate-and-fire neurons with exponentially decaying glutamatergic
# and GABAergic conductances. External input (stimulus + noise TimedArrays,
# indexed by neuron index i) enters with a minus sign in dv/dt. The two models
# differ only in their leak reversal and in E_GABA, which for excitatory
# neurons is read from the E_GABA_exc TimedArray at time t.
exc_neuron_eqs = '''
dv/dt = -(I_L + I_GLUT + I_GABA + I_ext)/C_m : volt (unless refractory)
I_ext = I_stimulus(t, i) + I_random(t, i) : amp
I_L = g_L * (v - E_L_exc) : amp
I_GLUT = g_GLUT * (v - E_GLUT) : amp
I_GABA = g_GABA * (v - E_GABA_exc(t)) : amp
dg_GLUT/dt = -g_GLUT / tau_GLUT : siemens
dg_GABA/dt = -g_GABA / tau_GABA : siemens
'''

inh_neuron_eqs = '''
dv/dt = -(I_L + I_GLUT + I_GABA + I_ext)/C_m : volt (unless refractory)
I_ext = I_stimulus(t, i) + I_random(t, i) : amp
I_L = g_L * (v - E_L_inh) : amp
I_GLUT = g_GLUT * (v - E_GLUT) : amp
I_GABA = g_GABA * (v - E_GABA_inh) : amp
dg_GLUT/dt = -g_GLUT / tau_GLUT : siemens
dg_GABA/dt = -g_GABA / tau_GABA : siemens
'''

# NOTE(review): b2.collect() runs before any NeuronGroup/Synapses exist here,
# so this presumably creates an empty network; all objects are added
# explicitly with net.add(...) below.
net = b2.Network(b2.collect())

# On crossing threshold, v is reset to the leak reversal potential and the
# neuron is refractory for t_ref; Euler integration.
population = dict()
population['exc'] = b2.NeuronGroup(params['exc']['count'],
                                   model = exc_neuron_eqs,
                                   threshold = 'v>V_th_exc',
                                   reset = 'v=E_L_exc',
                                   refractory = t_ref,
                                   method = 'euler')
population['inh'] = b2.NeuronGroup(params['inh']['count'],
                                   model = inh_neuron_eqs,
                                   threshold = 'v>V_th_inh',
                                   reset = 'v=E_L_inh',
                                   refractory = t_ref,
                                   method = 'euler')

# start all neurons at their leak reversal potential
population['exc'].v = E_L_exc
population['inh'].v = E_L_inh
net.add(population)
###########################################################
# CONNECTIVITY
# NOTE: names used inside the synapse model strings (beta, gmax, tau_stdp and
# `eta`, which is only defined later by the run loop) are resolved by Brian2
# from this module's namespace, so they must keep these exact names.

weight_ex = 0.1 * b2.nsiemens # Excitatory weight
weight_in = 1. * b2.nsiemens # Inhibitory weight
gmax = 100. * b2.nsiemens # Maximum inhibitory weight
tau_stdp = 20. * b2.ms # STDP time constant
rho = 3. * b2.Hz # Target excitatory population rate
beta = rho * tau_stdp * 2 # Target rate parameter (dimensionless: Hz * ms)
p_e = 0.1 # excitatory connection probability
p_i = 0.1 # inhibitory connection probability

# Static synapses: every presynaptic spike increments the postsynaptic
# glutamatergic (ex) or GABAergic (in) conductance by the fixed weight w.
static_ex_model = dict(model='w : siemens', on_pre='g_GLUT_post += w')
static_in_model = dict(model='w : siemens', on_pre='g_GABA_post += w')

# Plastic inhibitory synapse. Apre/Apost are exponentially decaying spike
# traces (time constant tau_stdp), each incremented by 1 nS per spike.
# On a presynaptic spike: w += eta * (Apost - beta * nS); on a postsynaptic
# spike: w += eta * Apre; in both cases w is clipped to [0, gmax].
# NOTE(review): with beta = 2 * rho * tau_stdp this looks like the inhibitory
# STDP rule of Vogels et al. (2011), which drives postsynaptic rates towards
# rho -- confirm against the paper.
vogels_model = dict(
    model='''
w : siemens
dApre/dt = -Apre / tau_stdp : siemens (event-driven)
dApost/dt = -Apost / tau_stdp : siemens (event-driven)
''',
    on_pre='''
Apre += 1.*nsiemens
w = clip(w + (Apost - beta * nsiemens) * eta, 0 * nsiemens, gmax)
g_GABA_post += w''',
    on_post='''
Apost += 1. * nsiemens
w = clip(w + Apre * eta, 0 * nsiemens, gmax)
''')

static_ex_synapse = partial(b2.Synapses, **static_ex_model)
static_in_synapse = partial(b2.Synapses, **static_in_model)
vogels_synapse = partial(b2.Synapses, **vogels_model)

# Connection table keyed by (presynaptic, postsynaptic) population; only the
# inhibitory -> excitatory projection is plastic.
conn_params = dict()
conn_params[('exc','exc')] = dict(model=static_ex_synapse, p=p_e, w=weight_ex)
conn_params[('exc','inh')] = dict(model=static_ex_synapse, p=p_e, w=weight_ex)
conn_params[('inh','inh')] = dict(model=static_in_synapse, p=p_i, w=weight_in)
conn_params[('inh','exc')] = dict(model=vogels_synapse, p=p_i, w=weight_in)

# random connectivity with probability p; all synapses of a projection start
# with the same initial weight w
connectivity = dict()
for connection in conn_params:
    connectivity[connection] = conn_params[connection]['model'](population[connection[0]], population[connection[1]])
    connectivity[connection].connect(p=conn_params[connection]['p'])
    connectivity[connection].w = conn_params[connection]['w']
net.add(connectivity)
# ###########################################
# MONITORS

# one spike monitor and one population-rate monitor per population,
# keyed like `params` ('exc', 'inh')
spike_monitor = {kind: b2.SpikeMonitor(population[kind]) for kind in params}
rate_monitor = {kind: b2.PopulationRateMonitor(population[kind]) for kind in params}

net.add(spike_monitor)
net.add(rate_monitor)
# ###########################################
# RUN SIMULATION

# Two phases: a balancing phase with inhibitory plasticity switched on
# (eta = 0.1), then a testing phase with plasticity frozen (eta = 0).
# `eta` is referenced by the plastic synapse model but never passed in
# explicitly; Brian2 picks it up from this (module) namespace at each
# `net.run` call, so it must remain a module-level name.
for simulation_time, eta in zip((balancing_time, testing_time), (0.1, 0.)):
    # perf_counter is monotonic, unlike time.time(), so the elapsed time
    # cannot be distorted by system clock adjustments
    tic = time.perf_counter()
    net.run(simulation_time * b2.ms)
    toc = time.perf_counter()
    print(f"Time elapsed: {toc-tic:.2f} seconds.")
# ###########################################
# ANALYSIS & PLOTS

figure_directory = pathlib.Path(args.figure_directory) if args.figure_directory else None

for ctr, E_GABA in enumerate(egaba):

    # Condition `ctr` occupies its own contiguous slice of the testing phase
    # (matching the E_GABA_exc schedule); all times below are in ms.
    condition_start = balancing_time + ctr * condition_duration
    condition_stop = condition_start + condition_duration

    # the example-response figure shows the first five epochs of the condition
    window_start = condition_start
    window_stop = condition_start + 5 * epoch_duration

    if args.figure_directory or args.show:
        # plot example input currents, spike rasters and population rates
        fig, (ax0a, ax0b, ax1, ax2, ax3) = plt.subplots(5, 1, sharex=True, figsize=(6.85, 12.5))

        # Only the displayed window is evaluated: calling the TimedArray once
        # per recorded sample over the whole run is very slow. (Not named
        # `time`, which would shadow the `time` module.)
        sample_times = rate_monitor['exc'].t
        sample_times_ms = np.asarray(sample_times / b2.ms)
        shown = (sample_times_ms >= window_start) & (sample_times_ms <= window_stop)
        for ax, neuron, legend in ((ax0a, 0, 'Example A'), (ax0b, 1, 'Example B')):
            currents = np.array([I_stimulus(t, neuron) / b2.pA for t in sample_times[shown]])
            ax.plot(sample_times_ms[shown], currents, label=legend)
            ax.legend(loc='upper right')
            ax.set_ylabel("Input currents [pA]")

        # spike rasters (loop variable is not called `population`, so the
        # module-level `population` dict is not clobbered)
        for ax, kind, ylabel in zip((ax1, ax2), params.keys(), ('Excitatory neurons', 'Inhibitory neurons')):
            ax.plot(spike_monitor[kind].t/b2.ms, spike_monitor[kind].i, '.',
                    markersize=2,
                    color=params[kind]['color'],
                    rasterized=True)
            ax.set_ylabel(ylabel)

        # plot population firing rates
        for kind in params.keys():
            ax3.plot(rate_monitor[kind].t/b2.ms,
                     rate_monitor[kind].smooth_rate(window='gaussian', width=10 * b2.ms)/b2.Hz,
                     label=params[kind]['name'],
                     color=params[kind]['color'])
        ax3.set_ylabel('Firing rate [Hz]')
        ax3.set_xlim(window_start, window_stop)
        ax3.legend(loc='upper right')
        ax3.set_xlabel('Time [ms]')
        fig.tight_layout()
        fig.align_ylabels()
        if figure_directory is not None:
            fig.savefig(figure_directory / f"{uuid}_E_GABA_{E_GABA:.1f}_example_responses.pdf")

    # excitatory spikes of this condition, with times relative to its start
    spike_times = np.asarray(spike_monitor['exc'].t / b2.ms)
    spike_indices = np.asarray(spike_monitor['exc'].i)
    in_condition = (spike_times >= condition_start) & (spike_times < condition_stop)
    spike_indices = spike_indices[in_condition]
    spike_times = spike_times[in_condition] - condition_start
    if not len(spike_times):
        raise ValueError("Simulation yielded no spikes during testing.")

    # compute stimulus discriminability:
    # spike-count matrix with one row per epoch (stimulus presentation) and one
    # column per excitatory neuron. COO matrices do not support slicing, so
    # convert to CSR; duplicate (epoch, neuron) entries are summed into counts.
    total_exc_neurons = params['exc']['count']
    epoch_indices = (spike_times / epoch_duration).astype(int)
    X = coo_matrix(
        (np.ones_like(spike_times), (epoch_indices, spike_indices)),
        shape=(total_condition_epochs, total_exc_neurons),
    ).tocsr()
    first_epoch = total_balancing_epochs + ctr * total_condition_epochs
    y = stimuli_labels[first_epoch:first_epoch + total_condition_epochs]
    # features: spike counts of the first 50 excitatory neurons only
    features = X[:, :50]
    accuracy = np.mean(cross_val_score(Classifier(), features, y, cv=StratifiedKFold(n_splits=5, shuffle=True)))

    if args.figure_directory or args.show:
        fig, ax = plt.subplots()
        xx, yy = UMAP().fit_transform(features).transpose()
        # apply small amounts of noise so samples within a class overlap less
        xx += 0.005 * np.ptp(xx) * np.random.randn(*xx.shape)
        yy += 0.005 * np.ptp(yy) * np.random.randn(*yy.shape)
        # one discrete colour per stimulus
        ax.scatter(xx, yy, c=y, cmap=plt.get_cmap('jet', total_stimuli), s=0.1, alpha=0.8)
        ax.set_xlabel('UMAP embedding\ndimension 1')
        ax.set_ylabel('UMAP embedding\ndimension 2')
        fig.tight_layout()
        if figure_directory is not None:
            fig.savefig(figure_directory / f"{uuid}_E_GABA_{E_GABA:.1f}_umap.pdf")

    # computing the neuron properties takes a very long time;
    # remove spikes outside of the first stimulus set
    in_first_set = (spike_times >= 0) & (spike_times < stimulus_set_duration)
    spike_indices = spike_indices[in_first_set]
    spike_times = spike_times[in_first_set]
    epoch_indices = epoch_indices[in_first_set]

    # retain only peri-stimulus spikes: times relative to epoch onset that
    # fall within `interval` (ms)
    spike_times = spike_times % epoch_duration
    interval = (0, 12)
    is_peristimulus = (spike_times >= interval[0]) & (spike_times < interval[1])
    spike_times = spike_times[is_peristimulus]
    spike_indices = spike_indices[is_peristimulus]
    epoch_indices = epoch_indices[is_peristimulus]

    # compute peri-stimulus spike counts & firing rates
    unique_neurons = np.unique(spike_indices)
    spike_counts = np.bincount(spike_indices, minlength=total_exc_neurons).astype(float)
    total_peristimulus_time = total_stimuli * (interval[1] - interval[0])  # ms
    firing_rates = spike_counts / (total_peristimulus_time / 1000)  # Hz

    # plot peri-stimulus histogram (KDE of all peri-stimulus spikes)
    if args.figure_directory or args.show:
        dt = 0.1
        psth_times = np.arange(*interval, dt)
        response = get_peristimulus_histogram(spike_times, interval, dt, kde=True)
        response = response / np.sum(response)
        fig, ax = plt.subplots()
        ax.plot(psth_times, response)
        ax.set_xlabel("Time [ms]")
        ax.set_ylabel("Population response")
        primary_response = np.sum(response[psth_times <= 2])
        secondary_response = np.sum(response[psth_times > 2])
        ax.set_title(f"Primary : secondary = {primary_response:.2f} : {secondary_response:.2f}")
        ax.grid(True)
        fig.tight_layout()
        if figure_directory is not None:
            fig.savefig(figure_directory / f"{uuid}_E_GABA_{E_GABA:.1f}_peristimulus_histogram.pdf")

    # compute peristimulus histogram entropy (1 ms bins); NaN for silent neurons
    entropy = np.full(total_exc_neurons, np.nan)
    for neuron in unique_neurons:
        times = spike_times[spike_indices == neuron]
        entropy[neuron] = get_peristimulus_histogram_entropy(times, interval, 1)

    # compute synchrony: for each neuron, the mean over its peri-stimulus
    # spikes of the fraction of the population's peri-stimulus spikes (in the
    # same epoch) that fall into the same 1 ms bin; NaN for silent neurons
    synchrony = dict()
    for epoch in range(total_stimuli):
        in_epoch = epoch_indices == epoch
        population_spikes = spike_times[in_epoch]
        if len(population_spikes) > 0:
            population_response = get_peristimulus_histogram(population_spikes, interval, 1)
            population_response /= population_response.sum()
            for neuron in unique_neurons:
                neuron_spikes = spike_times[in_epoch & (spike_indices == neuron)]
                if len(neuron_spikes) > 0:
                    indices = (neuron_spikes - interval[0]).astype(int)
                    s = population_response[indices].sum()
                    n = len(indices)
                    if neuron in synchrony:
                        synchrony[neuron][0] += s
                        synchrony[neuron][1] += n
                    else:
                        synchrony[neuron] = [s, n]
    synchrony = np.array([synchrony[neuron][0] / synchrony[neuron][1]
                          if neuron in synchrony else np.nan
                          for neuron in range(total_exc_neurons)])

    # per-neuron results for this condition
    data = dict()
    data['uid'] = uuid
    data['unit'] = np.arange(total_exc_neurons, dtype=int)
    data[r'E$_{GABA}$ [mV]'] = E_GABA
    data[r'E$_{L}$ [mV]'] = float(E_L_exc / b2.mV)
    data[r'V$_{th}$ [mV]'] = float(V_th_exc / b2.mV)
    data['Spike count'] = spike_counts
    data['Firing rate [Hz]'] = firing_rates
    data['Entropy [nats]'] = entropy
    data['Synchrony'] = synchrony
    data["Discriminability [%]"] = 100 * accuracy
    results = pd.DataFrame(data)

    if args.output:
        output_path = pathlib.Path(args.output)
        if output_path.exists():
            # append to the rows written by previous conditions / runs
            df = pd.concat([pd.read_csv(output_path), results], ignore_index=True)
        else:
            df = results
        df.to_csv(output_path, index=False)

    print()
    print("Results")
    print("=======")
    # numeric_only: the 'uid' column is not numeric (pandas >= 2 raises otherwise)
    print(results.mean(axis=0, numeric_only=True))
    print()

    if args.show:
        plt.ion()
        plt.show()
        input("Press any key to continue...")
    plt.close("all")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment