-
-
Save tlkahn/a173e40b7457caacda6f84dfe1695649 to your computer and use it in GitHub Desktop.
Conway's Game of Life Accelerated with Custom Kernels in MLX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import av | |
import numpy as np | |
import mlx.core as mx | |
def conway(a: mx.array): | |
source = """ | |
uint i = thread_position_in_grid.x; | |
uint j = thread_position_in_grid.y; | |
uint n = threads_per_grid.x; | |
uint m = threads_per_grid.y; | |
uint down = (i == 0) ? n : (i - 1); | |
uint up = (i + 1) == n ? 0 : (i + 1); | |
uint left = (j == 0) ? m : (j - 1); | |
uint right = (j + 1) == m ? 0 : (j + 1); | |
size_t idx = i * m + j; | |
int count = grid[up * m + right] + grid[up * m + j] | |
+ grid[i * m + right] + grid[up * m + left] + grid[down * m + left] | |
+ grid[down * m + j] + grid[i * m + left] + grid[down * m + right]; | |
if ((grid[idx] && count == 2) || count == 3) { | |
out[idx] = true; | |
} else { | |
out[idx] = false; | |
} | |
""" | |
kernel = mx.fast.metal_kernel( | |
name="conway", | |
source=source, | |
) | |
return kernel( | |
inputs={"grid": a}, | |
grid=(a.shape[0], a.shape[1], 1), | |
threadgroup=(2, 512, 1), | |
output_shapes={"out": a.shape}, | |
output_dtypes={"out": a.dtype}, | |
)["out"] | |
def generator(grid, steps=1000):
    """Yield ``steps`` uint8 frames of an evolving board, starting with ``grid``.

    Each frame is ``(~board)`` cast to uint8 and scaled by 255. For a boolean
    board, this makes dead cells 255 and live cells 0. After each frame, the
    board is advanced one generation with ``conway``.
    """
    state = grid
    for _ in range(steps):
        # Materialise the current board before rendering it and stepping it.
        mx.eval(state)
        inverted = (~state).astype(mx.uint8)
        yield inverted * 255
        state = conway(state)
def animate(generator, fps=8, path="out.mp4"):
    """Encode a sequence of grayscale frames into an H.264 video file.

    Args:
        generator: Iterable of 2-D frames (anything ``np.array`` converts to a
            gray8-compatible array). The first frame fixes the video size; all
            later frames are expected to have the same shape.
        fps: Frame rate of the output stream.
        path: Output file path. Defaults to ``"out.mp4"``, as before.

    Raises:
        ValueError: If ``generator`` yields no frames.
    """
    frames = iter(generator)
    try:
        first = next(frames)
    except StopIteration:
        # Without a frame there is no size to configure the encoder with.
        raise ValueError("animate() needs at least one frame") from None
    height, width = first.shape[0], first.shape[1]

    with av.open(path, mode="w") as container:
        stream = container.add_stream("h264", rate=fps)
        # Configure the stream once, before anything is encoded, instead of
        # reassigning its size on every frame.
        stream.width = width
        stream.height = height
        stream.pix_fmt = "yuv420p"

        def encode(img):
            # Wrap one frame as gray8 and mux whatever packets the encoder emits.
            # NOTE(review): relies on PyAV converting gray8 frames to the
            # stream's yuv420p format during encode — confirm for your PyAV version.
            frame = av.VideoFrame.from_ndarray(np.array(img), format="gray8")
            frame.pict_type = "NONE"  # presumably lets the encoder pick the type
            for packet in stream.encode(frame):
                container.mux(packet)

        encode(first)
        for img in frames:
            encode(img)

        # Flush frames still buffered inside the encoder.
        for packet in stream.encode():
            container.mux(packet)
if __name__ == "__main__":
    grid_size = 4096
    # Probability that each cell starts alive. Previously defined but ignored,
    # with 0.3 hard-coded again below.
    p_alive = 0.3
    grid = mx.random.bernoulli(p=p_alive, shape=(grid_size, grid_size))
    animate(generator(grid))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment