Skip to content

Instantly share code, notes, and snippets.

@paulbrodersen
Last active December 4, 2024 16:04
Show Gist options
  • Save paulbrodersen/244cdfca790e2a351d52d921f848e122 to your computer and use it in GitHub Desktop.
Save paulbrodersen/244cdfca790e2a351d52d921f848e122 to your computer and use it in GitHub Desktop.
Find synaptic inputs in voltage or current traces
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""This program identifies synaptic inputs in voltage and current
electrophysiological recordings.
Copyright (C) 2024 Gemma Gothard, Paul Brodersen
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import warnings
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks, savgol_filter
def _get_zero_crossings(y):
    """Return the indices i at which the sign of `y` differs between y[i] and y[i + 1]."""
    crossings, = np.where(np.diff(np.sign(y)))
    return crossings


def _deduplicate_inputs(magnitudes, onset_indices, peak_indices):
    """Drop inputs that share an onset or a peak with another input.

    The inputs must be given in detection order (i.e. ordered by derivative
    peak), because `np.unique(..., return_index=True)` keeps the first
    occurrence of each value. The returned arrays are sorted by peak index.
    """
    # Remove inputs that have the same onset as another input but had a slower
    # rise time (i.e. a later derivative peak, hence a later occurrence).
    onset_indices, keep = np.unique(onset_indices, return_index=True)
    peak_indices = peak_indices[keep]
    magnitudes = magnitudes[keep]
    # Remove inputs that have the same peak as another input. The arrays are
    # now sorted by onset, so the input with the earliest onset is kept.
    peak_indices, keep = np.unique(peak_indices, return_index=True)
    onset_indices = onset_indices[keep]
    magnitudes = magnitudes[keep]
    return magnitudes, onset_indices, peak_indices


def _plot_input_diagnostics(trace, filtered, derivative,
                            derivative_peak_indices, derivative_peak_heights,
                            magnitudes, onset_indices, peak_indices):
    """Plot the trace and its derivative, with each detected input shaded from onset to peak."""
    fig, axes = plt.subplots(2, 1, sharex=True, figsize=(10, 8))
    axes[0].plot(trace, color="tab:blue", label="Input trace")
    axes[0].plot(filtered, linestyle='--', color="tab:orange", label="Filtered trace")
    axes[0].plot(peak_indices, magnitudes + trace[onset_indices], '*', color='tab:orange', label="peaks")
    axes[1].plot(derivative, color="tab:blue", label="Derivative")
    axes[1].plot(derivative_peak_indices, derivative_peak_heights, '*', color='tab:blue', label="derivative peaks")
    for peak, onset in zip(peak_indices, onset_indices):
        for ax in axes:
            ax.axvspan(onset, peak, color="limegreen", alpha=0.5, zorder=-1)
    axes[0].set_ylim(*np.percentile(trace, [0.1, 99.9]))
    axes[1].set_ylim(*np.percentile(derivative, [0.1, 99.9]))
    axes[0].legend(loc="upper left")
    axes[1].legend(loc="upper left")
    fig.tight_layout()


def find_inputs(
    trace, fs,
    derivative_peak_threshold=1.,
    amplitude_threshold=5.,
    rise_time=2,
    show=False,
):
    """Find synaptic inputs in neuronal current or voltage traces
    based on peaks in the first derivative of the trace (i.e. events
    with fast rise times), and return input magnitude, the input onset
    index, and the input peak index.

    Arguments:
    ----------
    trace : NDArray[float]
        The neuronal current or voltage trace in pA or mV.
    fs : float
        Sampling frequency of the current or voltage trace in kHz.
    derivative_peak_threshold: float (default 1.0)
        The minimum prominence of peaks in the first derivative.
    amplitude_threshold: float (default 5)
        The minimum amplitude of inputs to be detected, in pA or mV.
        Only inputs with a magnitude strictly greater than this are kept.
    rise_time: float (default 2)
        The (expected) input rise time in ms. Onsets are searched up to
        2.5 * rise_time ms before, and peaks up to 5 * rise_time ms after,
        each derivative peak.
    show : bool (default False)
        If True, create a diagnostic figure showing the original trace,
        the filtered trace, its derivative, and the outputs. The figure is
        not displayed; call `plt.show()` to display it.

    Returns:
    --------
    magnitudes : NDArray[float]
        The input magnitudes in pA or mV.
    onset_indices : NDArray[int]
        The corresponding indices of input onsets.
    peak_indices : NDArray[int]
        The corresponding indices of input peaks.

    All three arrays are sorted by peak index and are empty (but still
    int-typed for the indices) when no input is found.
    """
    trace = np.asarray(trace)

    # Find peaks in the first derivative.
    # We apply a very mild local averaging to reduce the sensitivity to noise.
    filtered = savgol_filter(trace, 5, 3)
    derivative = np.r_[0, np.diff(filtered)]
    # height=0 makes find_peaks return the peak heights without enforcing a height threshold.
    derivative_peak_indices, derivative_peak_properties = find_peaks(
        derivative, prominence=derivative_peak_threshold, height=0)
    derivative_peak_heights = derivative_peak_properties['peak_heights']

    # The derivative peaks during the middle of the input rise time, where the trace increases the fastest.
    # We look for input onsets and peaks within a window around each derivative peak.
    # To avoid capturing partial inputs, we remove derivative peaks that are too close to the start or end of the trace.
    look_back = int(2.5 * rise_time * fs)
    look_ahead = int(5.0 * rise_time * fs)
    is_near_border = (
        (derivative_peak_indices < look_back)
        | (derivative_peak_indices > len(derivative) - look_ahead)
    )
    derivative_peak_indices = derivative_peak_indices[~is_near_border]
    derivative_peak_heights = derivative_peak_heights[~is_near_border]

    onset_indices, peak_indices, magnitudes = [], [], []
    for derivative_peak_index, derivative_peak_height in zip(derivative_peak_indices, derivative_peak_heights):
        # The onset is the last up-crossing of a baseline threshold before the derivative peak.
        # As the onset can be masked by noise in the baseline, we take a conservative estimate
        # (10% of the derivative peak height rather than zero).
        baseline_threshold = derivative_peak_height * 0.1
        window_start = derivative_peak_index - look_back
        # The window includes the derivative peak itself (the "+ 1"), so the final up-crossing
        # is found even when the sample just before the peak is still below threshold.
        onset_index_candidates = _get_zero_crossings(
            derivative[window_start:derivative_peak_index + 1] - baseline_threshold)
        if len(onset_index_candidates) == 0:
            # Could not find the onset of the input. Discard event.
            continue
        # "+ 1": take the first sample after the crossing, to correct for the
        # filter eating into the up-stroke of the input.
        onset_index = window_start + onset_index_candidates[-1] + 1

        # The peak is the first zero-crossing of the derivative after the derivative peak,
        # i.e. the first sample after which the filtered trace stops rising.
        peak_index_candidates = _get_zero_crossings(
            derivative[derivative_peak_index:derivative_peak_index + look_ahead])
        if len(peak_index_candidates) >= 1:
            peak_index = derivative_peak_index + peak_index_candidates[0]
        else:
            msg = "Could not find the peak of the input using the derivative. "
            msg += "Falling back to finding the maximum in the original trace within the specified window."
            warnings.warn(msg, stacklevel=2)
            peak_index = derivative_peak_index + np.argmax(
                trace[derivative_peak_index:derivative_peak_index + look_ahead])

        magnitude = trace[peak_index] - trace[onset_index]
        # Only capture inputs which are bigger than the amplitude threshold.
        if magnitude > amplitude_threshold:
            onset_indices.append(onset_index)
            peak_indices.append(peak_index)
            magnitudes.append(magnitude)

    # Explicit dtypes keep the index arrays integer-valued even when no input was found
    # (np.array([]) would otherwise be float and unusable for indexing).
    magnitudes, onset_indices, peak_indices = _deduplicate_inputs(
        np.array(magnitudes, dtype=float),
        np.array(onset_indices, dtype=int),
        np.array(peak_indices, dtype=int),
    )

    if show:
        _plot_input_diagnostics(
            trace, filtered, derivative,
            derivative_peak_indices, derivative_peak_heights,
            magnitudes, onset_indices, peak_indices,
        )

    return magnitudes, onset_indices, peak_indices
def get_average_input(trace, sampling_frequency, window=(-10, 40), *args, **kwargs):
    """Return the average waveform of the synaptic inputs found in a trace.

    Inputs are detected with `find_inputs`, and a window around each input
    onset is cut out of the raw trace and averaged sample by sample.

    Arguments:
    ----------
    trace : NDArray[float]
        The neuronal current or voltage trace in pA or mV.
    sampling_frequency : float
        Sampling frequency of the trace in kHz (passed to `find_inputs` as `fs`).
    window : tuple[float, float] (default (-10, 40))
        Start and stop of the cut-out relative to each onset, in ms
        (converted to samples as round(sampling_frequency * window[i])).
    *args, **kwargs
        Forwarded to `find_inputs` (e.g. `amplitude_threshold`, `rise_time`).

    Returns:
    --------
    average : NDArray[float]
        The mean waveform over all inputs whose full window lies inside the
        trace. If there are none, `np.mean` of an empty list is returned
        (NaN, with a NumPy RuntimeWarning).
    """
    trace = np.asarray(trace)
    _, onset_indices, _ = find_inputs(trace, sampling_frequency, *args, **kwargs)
    # Round to whole samples so that non-integer sampling frequencies give valid slice bounds.
    start = int(round(sampling_frequency * window[0]))
    stop = int(round(sampling_frequency * window[1]))
    segments = []
    for onset in onset_indices:
        first, last = onset + start, onset + stop
        if first < 0:
            # A negative start index would silently wrap around to the end of the trace.
            continue
        segment = trace[first:last]
        # Windows running past the end of the trace come back truncated; skip them.
        if len(segment) == stop - start:
            segments.append(segment)
    return np.mean(segments, axis=0)
def test(total_samples=1000, total_events=20, tau_rise=2, tau_decay=10):
    """Demo `find_inputs` on simulated dual-exponential synaptic inputs.

    Sums `total_events` randomly timed inputs, with amplitudes drawn from a
    normal distribution (mean 10, sd 2), adds unit-variance white noise, and
    runs `find_inputs` with its diagnostic figure enabled. The true event
    times are then drawn as dashed red lines on every axis, and the figure
    is shown.
    """
    time = np.arange(total_samples)
    signal = np.zeros_like(time, dtype=float)

    # Simulated inputs (uses NumPy's global random state).
    onsets = np.sort(np.random.randint(0, total_samples - 1, size=total_events))
    amplitudes = np.random.normal(10, 2, size=total_events)
    for onset, amplitude in zip(onsets, amplitudes):
        signal += amplitude * dual_exponential_waveform(time, onset, tau_rise, tau_decay)

    # White noise.
    signal += np.random.randn(len(signal))

    # Detection; the diagnostic figure becomes the current figure.
    find_inputs(signal, fs=1, show=True)

    # Ground truth overlaid on every axis.
    for ax in plt.gcf().axes:
        for onset in onsets:
            ax.axvline(onset, color="red", linestyle='--')
    plt.show()
def dual_exponential_waveform(t, onset, tau_rise, tau_decay):
    """Evaluate a dual-exponential waveform starting at `onset`.

    For t >= onset, with dt = t - onset, the waveform is

        tau_rise * tau_decay / (tau_decay - tau_rise)
            * (exp(-dt / tau_decay) - exp(-dt / tau_rise))

    and it is zero before the onset. When tau_rise == tau_decay this
    expression is 0/0. In that case its limit, the alpha function
    dt * exp(-dt / tau), is returned instead of raising ZeroDivisionError.

    Arguments:
    ----------
    t : array-like
        Sample times (any sequence convertible with `np.asarray`).
    onset : float
        Onset time, in the same units as `t`.
    tau_rise, tau_decay : float
        Rise and decay time constants, in the same units as `t`.

    Returns:
    --------
    y : NDArray[float]
        The waveform evaluated at `t`, with the same shape as `t`.
    """
    t = np.asarray(t)
    after_onset = t >= onset
    dt = t[after_onset] - onset
    y = np.zeros_like(t, dtype=float)
    if tau_rise == tau_decay:
        # Limit of the dual exponential as tau_rise -> tau_decay.
        y[after_onset] = dt * np.exp(-dt / tau_decay)
    else:
        rise = np.exp(-dt / tau_rise)
        decay = np.exp(-dt / tau_decay)
        y[after_onset] = tau_rise * tau_decay / (tau_decay - tau_rise) * (decay - rise)
    return y
if __name__ == "__main__":
    # Only when executed as a script: run the simulated-data demo.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment