# Gist by @russelljjarvis, created November 14, 2024:
# https://gist.github.com/russelljjarvis/fe6fd4f191e910d825e37dbc90d54a56
using CUDA
using Adapt
# Fail fast when CUDA cannot actually be used. `CUDA.functional()` verifies that the
# package is configured and ready to use (driver, runtime and a device);
# `CUDA.has_cuda()` only checks that a driver/runtime installation exists.
if !CUDA.functional()
    error("CUDA is not available on this system.")
else
    # Make accidental scalar indexing of GPU arrays an error instead of a silent,
    # very slow element-by-element host access.
    CUDA.allowscalar(false)
end
# Parameter struct for the Izhikevich (IZ) neuron model.
#
# `@snn_kw` is a project macro not shown here. Judging from this usage, it generates a
# keyword constructor with the defaults below and makes the element type default to
# `FT = Float32`. NOTE(review): confirm against the macro's definition.
#
# How `integrate_kernel` in this file uses the fields:
#   a -- scales the recovery update: u += dt * (a * (b * v - u))
#   b -- multiplies the membrane potential v inside that update
#   c -- value v is reset to after a spike (v > 30)
#   d -- amount added to u after a spike
@snn_kw struct IZParameter{FT = Float32}
a::FT = 0.01
b::FT = 0.2
c::FT = -65
d::FT = 2
end
# Izhikevich neuron population whose per-neuron state is held in GPU arrays (CUDA.jl).
# The array fields are separately type-parameterised so Adapt.jl can rebuild the struct
# with different array types (see `Adapt.@adapt_structure` below).
# Note: the type parameter `FT` here is the type of the `fire` array. It is unrelated
# to the element-type parameter of `IZParameter`.
mutable struct IZCUDA{VT<:AbstractArray, UT<:AbstractArray, FT<:AbstractArray, IT<:AbstractArray}
param::IZParameter # model parameters a, b, c, d
v::VT # membrane potential, one entry per neuron
u::UT # recovery variable, one entry per neuron
fire::FT # spike flags written by integrate_kernel (true where v > 30)
I::IT # input current added to the v update
N::Int # number of neurons; integrate! launches this many active threads
records::Dict # recorded variables; presumably filled by SNN.monitor -- verify
end
# Define `Adapt.adapt_structure` for IZCUDA so `adapt(to, x)` rebuilds the struct
# (via the positional constructor) with every field adapted.
Adapt.@adapt_structure IZCUDA
# Build an `N`-neuron IZCUDA population with all state allocated on the GPU:
# v starts at -65, u at -65 * param.b, no neuron has fired, the input current is
# zero, and `records` is an empty Dict.
function IZCUDA(N::Int, param::IZParameter)
    # Scalar indexing of the CuArrays below should raise an error, not run slowly.
    CUDA.allowscalar(false)
    membrane = CUDA.fill(-65.0f0, N)
    recovery = membrane * param.b
    spiked = CUDA.zeros(Bool, N)
    current = CUDA.zeros(Float32, N)
    records = Dict()
    return IZCUDA(param, membrane, recovery, spiked, current, N, records)
end
# Advance one Izhikevich neuron by a single step of size `dt`.
#
# Returns `(v_new, u_new, fired)`. `v` is integrated in two half-steps of `0.5 * dt`,
# then `u` is updated using the new `v`. If `v > 30`, the neuron fires: `v` is reset
# to `c` and `d` is added to `u`. Every intermediate is converted back to the type of
# the incoming `v`/`u` (`oftype`), exactly as storing into arrays of those element
# types would. So Float64 parameters still yield Float32 state when the state is
# Float32. This is a plain scalar function, so it can run inside the CUDA kernel and
# also be unit-tested on the CPU.
function _izhikevich_step(v, u, I, a, b, c, d, dt)
    v = oftype(v, v + 0.5f0 * dt * (0.04f0 * v^2 + 5.0f0 * v + 140.0f0 - u + I))
    v = oftype(v, v + 0.5f0 * dt * (0.04f0 * v^2 + 5.0f0 * v + 140.0f0 - u + I))
    u = oftype(u, u + dt * (a * (b * v - u)))
    fired = v > 30.0f0
    # Branch-free spike reset; `zero(d)` keeps the addition in d's type.
    v = ifelse(fired, oftype(v, c), v)
    u = oftype(u, u + ifelse(fired, d, zero(d)))
    return v, u, fired
end

# CUDA kernel: one thread per neuron, launched with a 1-D grid by `integrate!`.
# `idx` is the 1-based linear thread index; threads with `idx > N` do nothing.
# Each state array is read once and each output written once, rather than
# re-reading `v[idx]`/`fire[idx]` after every intermediate update.
function integrate_kernel(v, u, fire, I, a, b, c, d, dt, N)
    idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
    if idx <= N
        v_new, u_new, fired = _izhikevich_step(v[idx], u[idx], I[idx], a, b, c, d, dt)
        v[idx] = v_new
        u[idx] = u_new
        fire[idx] = fired
    end
    return nothing
end
# Advance every neuron of `p` by one step `dt` on the GPU and wait for completion.
#
# The model parameters are taken from the `param` argument (not from `p.param`).
# One thread is launched per neuron, in blocks of 256 threads. An empty population
# returns immediately, because a kernel launch with zero blocks is an invalid
# launch configuration.
function integrate!(p::IZCUDA, param::IZParameter, dt::Float32)
    N = p.N
    N > 0 || return nothing
    a, b, c, d = param.a, param.b, param.c, param.d
    threads = 256
    # Integer ceiling division: just enough blocks for one thread per neuron.
    blocks = cld(N, threads)
    @cuda threads=threads blocks=blocks integrate_kernel(
        p.v, p.u, p.fire, p.I, a, b, c, d, dt, N
    )
    # Ensure all GPU operations complete before the caller reads the state.
    CUDA.synchronize()
    return nothing
end
"""
[Izhikevich Neuron](https://www.izhikevich.org/publications/spikes.htm)
"""
IZCUDA
"""
Benchmarking code that shows how the CUDA approach is slower than CPU
"""
using CUDA
using Plots
using Adapt
using SpikingNeuralNetworks
SNN.@load_units
# Build and run the GPU benchmark network, then return the populations `[E, I]`.
#
# The network has `Ne` excitatory and `Ni` inhibitory IZCUDA neurons and four
# SpikingSynapse projections (E->E, E->I, I->E, I->I), all targeting `:v`. Spikes
# (`:fire`) are monitored. The network is simulated for `nsteps` steps of
# dt = 1.0f0, with fresh Gaussian input current drawn before every step.
function _gpu_sim(Ne::Int, Ni::Int, nsteps::Int)
    E = SNN.IZCUDA(Ne, SNN.IZParameter(; a = 0.02, b = 0.2, c = -65, d = 8))
    I = SNN.IZCUDA(Ni, SNN.IZParameter(; a = 0.1, b = 0.2, c = -65, d = 2))
    EE = SNN.SpikingSynapse(E, E, :v; μ = 0.5, p = 0.8)
    EI = SNN.SpikingSynapse(E, I, :v; μ = 0.5, p = 0.8)
    IE = SNN.SpikingSynapse(I, E, :v; μ = -1.0, p = 0.8)
    II = SNN.SpikingSynapse(I, I, :v; μ = -1.0, p = 0.8)
    P = [E, I]
    C = [EE, EI, IE, II]
    SNN.monitor([E, I], [:fire])
    for _ in 1:nsteps
        # Draw the input noise directly into the device arrays. This avoids a host
        # randn, a new CuArray allocation and a host->device copy on every step.
        CUDA.randn!(E.I)
        E.I .*= 5.0f0
        CUDA.randn!(I.I)
        I.I .*= 4.0f0
        SNN.sim!(P, C, 1.0f0)
    end
    return P
end

# Benchmark configurations: (excitatory neurons, inhibitory neurons, time steps).
sim0() = _gpu_sim(180, 180, 2000)
sim1() = _gpu_sim(1800, 1800, 2000)
sim2() = _gpu_sim(1800, 1800, 3000)
sim3() = _gpu_sim(2800, 2800, 6000)

# sim0 is timed twice: the first call also includes JIT compilation.
@time P = sim0()
@time P = sim0()
@time P = sim1()
@time P = sim2()
@time P = sim3()
#=
function convert_records(records)
return Dict{Symbol, Any}(key => Array.(records[key]) for key in keys(records))
end
# Convert records for all elements in P (each recorded array becomes a host `Array`)
for p in P # Iterate directly over the elements of P
p.records = convert_records(p.records)
end
SNN.raster(P)
=#
# Build and run the CPU benchmark network, then return the populations `[E, I]`.
#
# This mirrors the GPU benchmark with CPU `SNN.IZ` populations:
# - `Ne` excitatory and `Ni` inhibitory neurons;
# - four SpikingSynapse projections (E->E, E->I, I->E, I->I), all targeting `:v`;
# - spikes (`:fire`) monitored;
# - `nsteps` steps of dt = 1.0f0, with fresh Gaussian input current before every step.
function _cpu_sim(Ne::Int, Ni::Int, nsteps::Int)
    E = SNN.IZ(; N = Ne, param = SNN.IZParameter(; a = 0.02, b = 0.2, c = -65, d = 8))
    I = SNN.IZ(; N = Ni, param = SNN.IZParameter(; a = 0.1, b = 0.2, c = -65, d = 2))
    EE = SNN.SpikingSynapse(E, E, :v; μ = 0.5, p = 0.8)
    EI = SNN.SpikingSynapse(E, I, :v; μ = 0.5, p = 0.8)
    IE = SNN.SpikingSynapse(I, E, :v; μ = -1.0, p = 0.8)
    II = SNN.SpikingSynapse(I, I, :v; μ = -1.0, p = 0.8)
    P = [E, I]
    C = [EE, EI, IE, II]
    SNN.monitor([E, I], [:fire])
    for _ in 1:nsteps
        E.I .= 5randn(Ne)
        I.I .= 4randn(Ni)
        SNN.sim!(P, C, 1.0f0)
    end
    return P
end

# Benchmark configurations: (excitatory neurons, inhibitory neurons, time steps).
# These intentionally replace any earlier GPU definitions of sim0-sim3.
sim0() = _cpu_sim(180, 180, 2000)
sim1() = _cpu_sim(1800, 1800, 2000)
sim2() = _cpu_sim(1800, 1800, 3000)
sim3() = _cpu_sim(2800, 2800, 6000)

println("start CPU")
# sim0 is timed twice: the first call also includes JIT compilation.
@time P = sim0()
@time P = sim0()
@time P = sim1()
@time P = sim2()
@time P = sim3()
# (end of gist)