Last active
December 4, 2024 16:04
-
-
Save paulbrodersen/244cdfca790e2a351d52d921f848e122 to your computer and use it in GitHub Desktop.
Find synaptic inputs in voltage or current traces
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
"""This program identifies synaptic inputs in voltage and current | |
electrophysiological recordings. | |
Copyright (C) 2024 Gemma Gothard, Paul Brodersen | |
This program is free software: you can redistribute it and/or modify | |
it under the terms of the GNU General Public License as published by | |
the Free Software Foundation, either version 3 of the License, or | |
(at your option) any later version. | |
This program is distributed in the hope that it will be useful, | |
but WITHOUT ANY WARRANTY; without even the implied warranty of | |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
GNU General Public License for more details. | |
You should have received a copy of the GNU General Public License | |
along with this program. If not, see <https://www.gnu.org/licenses/>. | |
""" | |
import warnings | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.signal import find_peaks, savgol_filter | |
def _zero_crossings(y):
    """Return the indices i at which np.sign(y) changes between y[i] and y[i + 1].

    Because np.sign maps 0 to 0, steps onto or off an exact zero also count.
    """
    crossings, = np.where(np.diff(np.sign(y)))
    return crossings


def _deduplicate_inputs(magnitudes, onset_indices, peak_indices):
    """Drop inputs that share an onset or a peak with an earlier-listed input.

    Inputs are expected in detection order (ascending derivative-peak index).
    ``np.unique(..., return_index=True)`` keeps the first occurrence, so among
    inputs with the same onset the one with the earliest derivative peak
    (i.e. the fastest rise) survives. The returned index arrays always have an
    integer dtype, even when empty, so they remain usable for indexing.
    """
    magnitudes = np.asarray(magnitudes, dtype=float)
    onset_indices = np.asarray(onset_indices, dtype=int)
    peak_indices = np.asarray(peak_indices, dtype=int)

    # Same onset as another input but a slower rise (later derivative peak).
    onset_indices, keep = np.unique(onset_indices, return_index=True)
    peak_indices, magnitudes = peak_indices[keep], magnitudes[keep]

    # Same peak as another input.
    peak_indices, keep = np.unique(peak_indices, return_index=True)
    onset_indices, magnitudes = onset_indices[keep], magnitudes[keep]

    return magnitudes, onset_indices, peak_indices


def _plot_diagnostics(trace, filtered, derivative, derivative_peak_indices,
                      derivative_peak_heights, magnitudes, onset_indices,
                      peak_indices):
    """Plot trace, filtered trace and derivative, shading each detected input."""
    fig, (trace_ax, derivative_ax) = plt.subplots(2, 1, sharex=True, figsize=(10, 8))
    trace_ax.plot(trace, color="tab:blue", label="Input trace")
    trace_ax.plot(filtered, linestyle='--', color="tab:orange", label="Filtered trace")
    trace_ax.plot(peak_indices, magnitudes + trace[onset_indices], '*', color='tab:orange', label="peaks")
    derivative_ax.plot(derivative, color="tab:blue", label="Derivative")
    derivative_ax.plot(derivative_peak_indices, derivative_peak_heights, '*', color='tab:blue', label="derivative peaks")
    for onset, peak in zip(onset_indices, peak_indices):
        for ax in (trace_ax, derivative_ax):
            ax.axvspan(onset, peak, color="limegreen", alpha=0.5, zorder=-1)
    # Clip extreme outliers so that the inputs remain visible.
    trace_ax.set_ylim(*np.percentile(trace, [0.1, 99.9]))
    derivative_ax.set_ylim(*np.percentile(derivative, [0.1, 99.9]))
    trace_ax.legend(loc="upper left")
    derivative_ax.legend(loc="upper left")
    fig.tight_layout()


def find_inputs(
    trace, fs,
    derivative_peak_threshold=1.,
    amplitude_threshold=5.,
    rise_time=2,
    show=False,
):
    """Find synaptic inputs in neuronal current or voltage traces.

    Detection is based on peaks in the first derivative of a lightly smoothed
    copy of the trace, i.e. on events with fast rise times. Around each
    derivative peak, the input onset is placed just after the last crossing of
    10 % of the derivative-peak height before the peak. The input peak is
    placed at the first zero-crossing of the derivative after the derivative
    peak. If no such zero-crossing exists, it falls back to the maximum of the
    raw trace in the look-ahead window and emits a warning.

    Arguments:
    ----------
    trace : NDArray[float]
        The neuronal current or voltage trace in pA or mV. It is converted to
        a float64 array, so lists and integer recordings are accepted.
    fs : float
        Sampling frequency of the current or voltage trace in kHz.
    derivative_peak_threshold : float (default 1.0)
        Minimum prominence of peaks in the first derivative. The derivative is
        a sample-to-sample difference, so this threshold is in pA or mV per
        sample and hence depends on `fs`.
    amplitude_threshold : float (default 5.0)
        Inputs must be strictly larger than this, in pA or mV.
    rise_time : float (default 2)
        The (expected) input rise time in ms. Onsets are searched up to
        2.5 * rise_time before each derivative peak, and peaks up to
        5 * rise_time after it. Derivative peaks closer than that to either
        end of the trace are ignored to avoid partial inputs.
    show : bool (default False)
        Plot a diagnostic figure showing the original trace, the filtered
        trace, its derivative, and the outputs.

    Returns:
    --------
    magnitudes : NDArray[float]
        The input magnitudes (trace at peak minus trace at onset) in pA or mV.
    onset_indices : NDArray[int]
        The corresponding indices of input onsets.
    peak_indices : NDArray[int]
        The corresponding indices of input peaks.

    All three arrays are empty (with integer index dtypes) if nothing is found.
    Inputs that share an onset or a peak with another input are dropped, and
    the outputs are ordered by ascending peak index.
    """
    trace = np.asarray(trace, dtype=float)

    # Very mild local averaging to reduce the sensitivity of the derivative to noise.
    filtered = savgol_filter(trace, 5, 3)
    # Prepend 0 so that derivative[i] == filtered[i] - filtered[i - 1] stays aligned with trace.
    derivative = np.r_[0, np.diff(filtered)]
    # height=0 makes find_peaks report the peak heights; it also excludes negative peaks.
    derivative_peak_indices, properties = find_peaks(
        derivative, prominence=derivative_peak_threshold, height=0)
    derivative_peak_heights = properties['peak_heights']

    # The derivative peaks mid-way through the rise, where the trace increases fastest.
    # Onsets and peaks are searched in a window around it. Peaks whose window
    # would run off either end of the trace are dropped to avoid partial inputs.
    look_back = int(2.5 * rise_time * fs)
    look_ahead = int(5.0 * rise_time * fs)
    is_near_border = (
        (derivative_peak_indices < look_back)
        | (derivative_peak_indices > len(derivative) - look_ahead)
    )
    derivative_peak_indices = derivative_peak_indices[~is_near_border]
    derivative_peak_heights = derivative_peak_heights[~is_near_border]

    onset_indices, peak_indices, magnitudes = [], [], []
    for derivative_peak_index, derivative_peak_height in zip(derivative_peak_indices, derivative_peak_heights):
        # The onset can be masked by baseline noise, so it is estimated
        # conservatively: relative to 10 % of the derivative peak height
        # rather than relative to zero.
        baseline_threshold = 0.1 * derivative_peak_height
        window_start = derivative_peak_index - look_back
        onset_candidates = _zero_crossings(
            derivative[window_start:derivative_peak_index] - baseline_threshold)
        if len(onset_candidates) == 0:
            # Could not find the onset of the input; discard the event.
            continue
        # The +1 moves from the last sample before the crossing to the first
        # sample after it (the original author notes this also corrects for
        # the filter eating into the up-stroke of the input).
        onset_index = window_start + onset_candidates[-1] + 1

        peak_candidates = _zero_crossings(
            derivative[derivative_peak_index:derivative_peak_index + look_ahead])
        if len(peak_candidates) > 0:
            peak_index = derivative_peak_index + peak_candidates[0]
        else:
            msg = "Could not find the peak of the input using the derivative. "
            msg += "Falling back to finding the maximum in the original trace within the specified window."
            warnings.warn(msg, stacklevel=2)
            peak_index = derivative_peak_index + np.argmax(
                trace[derivative_peak_index:derivative_peak_index + look_ahead])

        magnitude = trace[peak_index] - trace[onset_index]
        # Only keep inputs bigger than the amplitude threshold.
        if magnitude > amplitude_threshold:
            onset_indices.append(onset_index)
            peak_indices.append(peak_index)
            magnitudes.append(magnitude)

    magnitudes, onset_indices, peak_indices = _deduplicate_inputs(
        magnitudes, onset_indices, peak_indices)

    if show:
        _plot_diagnostics(trace, filtered, derivative, derivative_peak_indices,
                          derivative_peak_heights, magnitudes, onset_indices,
                          peak_indices)

    return magnitudes, onset_indices, peak_indices
def get_average_input(trace, sampling_frequency, window=(-10, 40), *args, **kwargs):
    """Return the average input waveform, aligned on the detected input onsets.

    Arguments:
    ----------
    trace : NDArray[float]
        The neuronal current or voltage trace in pA or mV.
    sampling_frequency : float
        Sampling frequency in kHz (passed to `find_inputs` as `fs`).
    window : tuple(float, float) (default (-10, 40))
        Start and stop of each snippet relative to the input onset. They are
        converted to samples by multiplying with `sampling_frequency`, the
        same convention `find_inputs` uses for `rise_time`, so they are in ms.
    *args, **kwargs
        Forwarded to `find_inputs` (derivative_peak_threshold,
        amplitude_threshold, rise_time, show).

    Returns:
    --------
    average : NDArray[float]
        Mean of all snippets that lie entirely within the trace. If there is
        no such snippet, an all-NaN array of the same length is returned and a
        warning is emitted.
    """
    trace = np.asarray(trace)
    _, onset_indices, _ = find_inputs(trace, sampling_frequency, *args, **kwargs)

    # Slice bounds must be integers even when the sampling frequency is a float.
    start = int(round(sampling_frequency * window[0]))
    stop = int(round(sampling_frequency * window[1]))

    # Check bounds explicitly instead of relying on negative-index wraparound.
    snippets = [
        trace[onset + start:onset + stop]
        for onset in onset_indices
        if onset + start >= 0 and onset + stop <= len(trace)
    ]
    if not snippets:
        warnings.warn("No complete input snippets found within the window; returning NaNs.")
        return np.full(max(stop - start, 0), np.nan)
    return np.mean(snippets, axis=0)
def test(total_samples=1000, total_events=20, tau_rise=2, tau_decay=10):
    """Visually check `find_inputs` on synthetic dual-exponential inputs.

    Builds a trace of `total_events` dual-exponential inputs with random onsets
    and normally distributed amplitudes (mean 10, sd 2), then adds
    unit-variance Gaussian white noise. It runs `find_inputs` with its
    diagnostic figure enabled and marks the true onsets with red dashed lines
    on every axis of that figure.
    """
    time = np.arange(total_samples)
    true_onsets = np.sort(np.random.randint(0, total_samples - 1, size=total_events))
    amplitudes = np.random.normal(10, 2, size=total_events)

    trace = np.zeros_like(time, dtype=float)
    for onset, amplitude in zip(true_onsets, amplitudes):
        trace += amplitude * dual_exponential_waveform(time, onset, tau_rise, tau_decay)
    trace += np.random.randn(len(trace))

    # find_inputs(show=True) leaves its diagnostic figure as the current figure.
    find_inputs(trace, fs=1, show=True)
    for axis in plt.gcf().axes:
        for onset in true_onsets:
            axis.axvline(onset, color="red", linestyle='--')
    plt.show()
def dual_exponential_waveform(t, onset, tau_rise, tau_decay):
    """Return a dual-exponential waveform that starts at `onset`.

    With dt = t - onset, the waveform is, for t >= onset,

        tau_rise * tau_decay / (tau_decay - tau_rise)
            * (exp(-dt / tau_decay) - exp(-dt / tau_rise))

    and 0 before the onset. When tau_rise == tau_decay this expression is
    0/0, so its analytic limit, the alpha function dt * exp(-dt / tau), is
    returned instead.

    Arguments:
    ----------
    t : array-like
        Sample times (same units as onset and the time constants).
    onset : float
        Start time of the waveform.
    tau_rise, tau_decay : float
        Rise and decay time constants.

    Returns:
    --------
    y : NDArray[float]
        Waveform with the same shape as `t`.
    """
    t = np.asarray(t)
    y = np.zeros_like(t, dtype=float)
    after_onset = t >= onset
    dt = t[after_onset] - onset
    if tau_rise == tau_decay:
        # Limit of the general formula as tau_rise -> tau_decay.
        y[after_onset] = dt * np.exp(-dt / tau_decay)
    else:
        rise = np.exp(-dt / tau_rise)
        decay = np.exp(-dt / tau_decay)
        y[after_onset] = tau_rise * tau_decay / (tau_decay - tau_rise) * (decay - rise)
    return y
# When executed as a script, run the synthetic-data demo, which plots
# find_inputs' diagnostic figure with the true input onsets marked.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment