Created
November 14, 2024 21:57
-
-
Save russelljjarvis/fe6fd4f191e910d825e37dbc90d54a56 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using CUDA | |
using Adapt | |
# Abort immediately when `CUDA.has_cuda()` reports that CUDA cannot be used here.
# Otherwise turn off scalar indexing of GPU arrays globally, so that accidental
# element-by-element access from the host raises an error instead of silently
# running slowly.
CUDA.has_cuda() || error("CUDA is not available on this system.")
CUDA.allowscalar(false)
# Izhikevich model parameters, stored with element type `FT` (Float32 unless given).
# `@snn_kw` presumably generates the keyword constructor used by callers, e.g.
# `IZParameter(; a = 0.02, b = 0.2, c = -65, d = 8)`. Defaults are listed below.
@snn_kw struct IZParameter{FT = Float32}
    a::FT = 0.01   # rate constant of u's relaxation toward b * v
    b::FT = 0.2    # coupling of the recovery variable u to v
    c::FT = -65    # value v is reset to after a spike
    d::FT = 2      # amount added to u after a spike
end
# GPU-capable population of Izhikevich neurons (state arrays are typically CuArrays).
#
# Fields
# - `param`:   model parameters. Its concrete type is captured in `PT`, so reading
#              `p.param` is type-stable; an untyped `IZParameter` field is abstract.
# - `v`:       membrane potential, one entry per neuron.
# - `u`:       recovery variable, one entry per neuron.
# - `fire`:    spike flags written by every integration step.
# - `I`:       input current, one entry per neuron.
# - `N`:       number of neurons.
# - `records`: storage filled by monitoring code such as `SNN.monitor` (contents are
#              defined by that code; left as a plain `Dict`).
mutable struct IZCUDA{
    VT<:AbstractArray,
    UT<:AbstractArray,
    FT<:AbstractArray,
    IT<:AbstractArray,
    PT<:IZParameter,
}
    param::PT
    v::VT
    u::UT
    fire::FT
    I::IT
    N::Int
    records::Dict
end

# Let Adapt.jl rebuild an IZCUDA with each field adapted (e.g. when converting
# between host and device array types).
Adapt.@adapt_structure IZCUDA
# Build an `IZCUDA` population of `N` neurons with all state allocated as CUDA arrays:
# every `v` starts at -65, `u` at `-65 * param.b`, and the spike flags and input
# currents start at zero. The records dictionary starts empty.
function IZCUDA(N::Int, param::IZParameter)
    # Re-assert the global "no scalar indexing" policy for GPU arrays.
    CUDA.allowscalar(false)
    v_init = CUDA.fill(-65.0f0, N)
    u_init = CUDA.fill(-65.0f0, N) * param.b
    spike_flags = CUDA.zeros(Bool, N)
    input_current = CUDA.zeros(Float32, N)
    return IZCUDA(param, v_init, u_init, spike_flags, input_current, N, Dict())
end
"""
    integrate_kernel(v, u, fire, I, a, b, c, d, dt, N)

CUDA kernel that advances one neuron of an Izhikevich population by one step `dt`.

Each thread computes a global 1-based index `idx` from its thread/block indices and
updates neuron `idx`; threads with `idx > N` do nothing. `v` is advanced in two
half-steps of `dt / 2`, then `u` is updated from the new `v`. If `v` then exceeds 30,
`fire[idx]` is set, `v` is reset to `c` and `u` is incremented by `d`. Otherwise
`fire[idx]` is cleared and `v`, `u` keep their integrated values.
"""
function integrate_kernel(v, u, fire, I, a, b, c, d, dt, N)
    idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    if idx <= N
        # Two half-steps for v (Izhikevich's published reference code does the same,
        # presumably for numerical stability).
        v[idx] += 0.5f0 * dt *
                  (0.04f0 * v[idx]^2 + 5.0f0 * v[idx] + 140.0f0 - u[idx] + I[idx])
        v[idx] += 0.5f0 * dt *
                  (0.04f0 * v[idx]^2 + 5.0f0 * v[idx] + 140.0f0 - u[idx] + I[idx])
        u[idx] += dt * (a * (b * v[idx] - u[idx]))
        # Spike detection and reset. A branch on a local flag avoids the
        # `Union{typeof(d), Float32}` that `fire[idx] ? d : 0.0f0` produces for
        # non-Float32 parameters. It also avoids re-reading `fire` from global memory
        # and skips the no-op writes for neurons that did not spike.
        spiked = v[idx] > 30.0f0
        fire[idx] = spiked
        if spiked
            v[idx] = c
            u[idx] += d
        end
    end
    return
end
"""
    integrate!(p::IZCUDA, param::IZParameter, dt::Float32; threads = 256)

Advance every neuron of `p` by one time step `dt` on the GPU, in place.

Launches `integrate_kernel` on a 1-D grid of `cld(p.N, threads)` blocks of `threads`
threads each, which covers all `p.N` neurons. It then calls `CUDA.synchronize()`, so
the kernel has finished writing `p.v`, `p.u` and `p.fire` when this returns.

# Arguments
- `p`: population to update.
- `param`: Izhikevich parameters; `a`, `b`, `c`, `d` are passed to the kernel.
- `dt`: integration step.

# Keywords
- `threads`: threads per CUDA block (default 256).

# Returns
`nothing`. An empty population (`p.N == 0`) returns without launching anything.

# Throws
`ArgumentError` if `threads` is not positive.
"""
function integrate!(p::IZCUDA, param::IZParameter, dt::Float32; threads::Integer = 256)
    threads > 0 || throw(ArgumentError("threads must be positive, got $threads"))
    N = p.N
    # With N == 0 the grid would have zero blocks, which the CUDA launcher rejects.
    N > 0 || return nothing
    a, b, c, d = param.a, param.b, param.c, param.d
    blocks = cld(N, threads)
    @cuda threads=threads blocks=blocks integrate_kernel(
        p.v, p.u, p.fire, p.I, a, b, c, d, dt, N,
    )
    # Block until the kernel has finished so callers see the updated state.
    CUDA.synchronize()
    return nothing
end
"""
    IZCUDA

Population of [Izhikevich neurons](https://www.izhikevich.org/publications/spikes.htm)
whose state (`v`, `u`, `fire`, `I`) is stored in CUDA arrays.

Create one with `IZCUDA(N, param)`, which sets every `v` to -65 and `u` to
`param.b * v`. Advance it by one step with `integrate!(p, param, dt)`, which launches a
CUDA kernel and waits for it to finish before returning.
"""
IZCUDA
# Benchmarking code that shows how the CUDA approach is slower than CPU:
# the same E/I networks are timed with GPU (SNN.IZCUDA) and CPU (SNN.IZ) neurons.
# (Kept as a comment: a string literal directly above `using` is parsed as a
# docstring for an expression that cannot be documented.)
using CUDA
using Plots
using Adapt
using SpikingNeuralNetworks
# Bring SNN's unit definitions into scope (presumably ms, mV, ...; confirm in SNN).
SNN.@load_units
"""
    simulate_cuda_network(Ne, Ni, steps)

Build an excitatory/inhibitory Izhikevich network with GPU neurons and run it.

Creates `Ne` excitatory and `Ni` inhibitory `SNN.IZCUDA` neurons and four
`SNN.SpikingSynapse` projections (EE, EI, IE, II) onto `:v`, and monitors `:fire`.
It then calls `SNN.sim!` `steps` times with `dt = 1.0f0`. Before every step the input
currents are redrawn as Gaussian noise with standard deviation 5 (E) and 4 (I).
Returns the populations `[E, I]`.
"""
function simulate_cuda_network(Ne, Ni, steps)
    E = SNN.IZCUDA(Ne, SNN.IZParameter(; a = 0.02, b = 0.2, c = -65, d = 8))
    I = SNN.IZCUDA(Ni, SNN.IZParameter(; a = 0.1, b = 0.2, c = -65, d = 2))
    EE = SNN.SpikingSynapse(E, E, :v; μ = 0.5, p = 0.8)
    EI = SNN.SpikingSynapse(E, I, :v; μ = 0.5, p = 0.8)
    IE = SNN.SpikingSynapse(I, E, :v; μ = -1.0, p = 0.8)
    II = SNN.SpikingSynapse(I, I, :v; μ = -1.0, p = 0.8)
    P = [E, I]
    C = [EE, EI, IE, II]
    SNN.monitor([E, I], [:fire])
    for _ in 1:steps
        # Draw the noise directly on the device. Building `CuArray(5randn(Ne))` would
        # allocate a host array and copy it to the GPU on every step, and that copy
        # would be counted in the CUDA timings.
        CUDA.randn!(E.I)
        E.I .*= 5.0f0
        CUDA.randn!(I.I)
        I.I .*= 4.0f0
        SNN.sim!(P, C, 1.0f0)
    end
    return P
end

# Benchmark sizes: (excitatory, inhibitory, time steps).
sim0() = simulate_cuda_network(180, 180, 2000)
sim1() = simulate_cuda_network(1800, 1800, 2000)
sim2() = simulate_cuda_network(1800, 1800, 3000)
sim3() = simulate_cuda_network(2800, 2800, 6000)
# Label this section so the CUDA timings can be told apart from the CPU ones below.
println("start CUDA")
# sim0 runs twice: the first call presumably includes compilation (warm-up),
# so the second reflects steady-state cost.
@time P = sim0()
@time P = sim0()
@time P = sim1()
@time P = sim2()
@time P = sim3()
#= | |
function convert_records(records) | |
return Dict{Symbol, Any}(key => Array.(records[key]) for key in keys(records)) | |
end | |
# Convert the recorded arrays of every population in P to host `Array`s
for p in P # Iterate directly over the populations in P
p.records = convert_records(p.records) | |
end | |
SNN.raster(P) | |
=# | |
"""
    simulate_cpu_network(Ne, Ni, steps)

Build an excitatory/inhibitory Izhikevich network with CPU neurons and run it.

Creates `Ne` excitatory and `Ni` inhibitory `SNN.IZ` neurons and four
`SNN.SpikingSynapse` projections (EE, EI, IE, II) onto `:v`, and monitors `:fire`.
It then calls `SNN.sim!` `steps` times with `dt = 1.0f0`. Before every step the input
currents are redrawn as Gaussian noise with standard deviation 5 (E) and 4 (I).
Returns the populations `[E, I]`.
"""
function simulate_cpu_network(Ne, Ni, steps)
    E = SNN.IZ(; N = Ne, param = SNN.IZParameter(; a = 0.02, b = 0.2, c = -65, d = 8))
    I = SNN.IZ(; N = Ni, param = SNN.IZParameter(; a = 0.1, b = 0.2, c = -65, d = 2))
    EE = SNN.SpikingSynapse(E, E, :v; μ = 0.5, p = 0.8)
    EI = SNN.SpikingSynapse(E, I, :v; μ = 0.5, p = 0.8)
    IE = SNN.SpikingSynapse(I, E, :v; μ = -1.0, p = 0.8)
    II = SNN.SpikingSynapse(I, I, :v; μ = -1.0, p = 0.8)
    P = [E, I]
    C = [EE, EI, IE, II]
    SNN.monitor([E, I], [:fire])
    for _ in 1:steps
        E.I .= 5randn(Ne)
        I.I .= 4randn(Ni)
        SNN.sim!(P, C, 1.0f0)
    end
    return P
end

# Benchmark sizes: (excitatory, inhibitory, time steps). These redefine the CUDA
# versions of sim0-sim3 above with identical network sizes.
sim0() = simulate_cpu_network(180, 180, 2000)
sim1() = simulate_cpu_network(1800, 1800, 2000)
sim2() = simulate_cpu_network(1800, 1800, 3000)
sim3() = simulate_cpu_network(2800, 2800, 6000)
println("start CPU")
# Time each CPU benchmark and keep the last result in `P`. sim0 is timed twice:
# the first run presumably includes compilation (warm-up).
P = @time sim0()
P = @time sim0()
P = @time sim1()
P = @time sim2()
P = @time sim3()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment