Skip to content

Instantly share code, notes, and snippets.

@awni
Last active February 7, 2025 21:39
Show Gist options
  • Save awni/fde217c67e6be098e0773d3a7de93f02 to your computer and use it in GitHub Desktop.
Save awni/fde217c67e6be098e0773d3a7de93f02 to your computer and use it in GitHub Desktop.
Conway's Game of Life Accelerated with Custom Kernels in MLX
import numpy as np
import mlx.core as mx
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import tqdm
def conway(a: mx.array):
    """Advance a Game of Life board by one generation with a custom Metal kernel.

    One GPU thread is launched per cell (the launch grid is
    ``(a.shape[0], a.shape[1], 1)``). Each thread sums its eight neighbours,
    wrapping around at the board edges, and writes the cell's next state.

    Args:
        a: Board to step. The kernel indexes it as ``grid[i * m + j]`` and
            treats non-zero entries as live cells.
            NOTE(review): assumes ``a`` is 2-D -- confirm with callers.

    Returns:
        A new array with the same shape and dtype as ``a`` holding the next
        generation.
    """
    source = """
        uint i = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint n = threads_per_grid.x;
        uint m = threads_per_grid.y;

        // Toroidal neighbours: the last valid index is n - 1 (resp. m - 1),
        // so stepping back from 0 must land there, not on n (resp. m),
        // which would read past the end of the board.
        uint down = (i == 0) ? (n - 1) : (i - 1);
        uint up = (i + 1) == n ? 0 : (i + 1);
        uint left = (j == 0) ? (m - 1) : (j - 1);
        uint right = (j + 1) == m ? 0 : (j + 1);

        size_t idx = i * m + j;
        int count = grid[up * m + right] + grid[up * m + j]
          + grid[i * m + right] + grid[up * m + left] + grid[down * m + left]
          + grid[down * m + j] + grid[i * m + left] + grid[down * m + right];

        // Survive with two neighbours; be born or survive with three.
        if ((grid[idx] && count == 2) || count == 3) {
          out[idx] = true;
        } else {
          out[idx] = false;
        }
    """
    kernel = mx.fast.metal_kernel(
        name="conway",
        input_names=["grid"],
        output_names=["out"],
        source=source,
    )
    return kernel(
        inputs=[a],
        grid=(a.shape[0], a.shape[1], 1),
        threadgroup=(2, 512, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )[0]
def generator(grid, steps=1000):
    """Yield ``steps`` successive generations as 0/255 ``uint8`` frames.

    Each frame is the bitwise inverse of the current board scaled by 255
    (for a boolean board: live cells become 0, dead cells become 255).
    The board is evaluated before each frame is produced and advanced with
    ``conway`` afterwards.
    """
    for _ in range(steps):
        # Force evaluation of the pending computation for this generation.
        mx.eval(grid)
        inverted = ~grid
        yield inverted.astype(mx.uint8) * 255
        grid = conway(grid)
def animate(grid, steps=300, fps=30, save_as="out.mp4"):
    """Render a Game of Life run to a video file, then show the figure.

    Args:
        grid: Initial board, handed to ``generator`` to produce the frames.
        steps: Number of frames requested from ``generator``; also the
            progress bar total.
        fps: Frame rate of the saved video and basis for the on-screen
            frame delay.
        save_as: Output path passed to ``FuncAnimation.save`` (written with
            the ``ffmpeg`` writer).
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    im = ax.imshow(mx.zeros_like(grid), cmap="gray", vmin=0, vmax=255)
    progress_bar = tqdm.tqdm(total=steps, desc="Animating", ncols=100)

    def update(frame):
        # Swap in the new frame and report the artist that changed (blitting).
        im.set_data(frame)
        progress_bar.update(1)
        return [im]

    ani = FuncAnimation(
        fig,
        update,
        frames=generator(grid, steps=steps),
        # ``interval`` is the delay between frames in milliseconds, so one
        # frame period is 1000 / fps (it does not depend on ``steps``).
        interval=1000 // fps,
        blit=True,
        cache_frame_data=False,
    )
    fig.tight_layout()
    try:
        ani.save(save_as, writer="ffmpeg", fps=fps, dpi=300)
    finally:
        # Close the bar even if encoding fails so the terminal is left clean.
        progress_bar.close()
    plt.show()
if __name__ == "__main__":
    # Random square board with roughly 30% of the cells alive.
    side = 2048
    animate(mx.random.bernoulli(p=0.3, shape=(side, side)))
@stockeh
Copy link

stockeh commented Aug 30, 2024

Awesome! Here are the changes to use matplotlib instead of av for rendering and without numpy:

import mlx.core as mx
import matplotlib.pyplot as plt

from matplotlib.animation import FuncAnimation
from tqdm import tqdm

# def conway(a: mx.array): ...

# def generator(grid, steps=1000): ...

def animate(grid, steps=300, fps=30, save_as='out.mp4'):
    """Render a Game of Life run to a video file, then show the figure.

    Args:
        grid: Initial board, handed to ``generator`` to produce the frames.
        steps: Number of frames requested from ``generator``; also the
            progress bar total.
        fps: Frame rate of the saved video and basis for the on-screen
            frame delay.
        save_as: Output path passed to ``FuncAnimation.save`` (written with
            the ``ffmpeg`` writer).
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    im = ax.imshow(mx.zeros_like(grid), cmap='gray', vmin=0, vmax=255)

    progress_bar = tqdm(total=steps, desc='Animating', ncols=100)

    def update(frame):
        # Swap in the new frame and report the artist that changed (blitting).
        im.set_data(frame)
        progress_bar.update(1)
        return [im]

    ani = FuncAnimation(
        fig,
        update,
        frames=generator(grid, steps=steps),
        # ``interval`` is the delay between frames in milliseconds, so one
        # frame period is 1000 / fps (it does not depend on ``steps``).
        interval=1000 // fps,
        blit=True,
        cache_frame_data=False,
    )

    fig.tight_layout()
    try:
        ani.save(save_as, writer='ffmpeg', fps=fps, dpi=300)
    finally:
        # Close the bar even if encoding fails so the terminal is left clean.
        progress_bar.close()

    plt.show()

if __name__ == "__main__":
    # Small rectangular board with roughly 30% of the cells alive.
    rows, cols = 128, 265
    animate(mx.random.bernoulli(p=0.3, shape=(rows, cols)))

@cdotwang
Copy link

cdotwang commented Oct 3, 2024

This code hits an error with the latest MLX release, 0.18.0:

TypeError: metal_kernel(): incompatible function arguments. The following argument types are supported:
    1. metal_kernel(name: str, input_names: collections.abc.Sequence[str], output_names: collections.abc.Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) -> object

Invoked with types: kwargs = { name: str, source: str }

@awni
Copy link
Author

awni commented Oct 4, 2024

I updated it.

@ivanfioravanti
Copy link

ivanfioravanti commented Feb 7, 2025

Cool!!!! Here is a version that renders the video in real time using pygame.

import numpy as np
import mlx.core as mx
import pygame

def conway(a: mx.array):
    """Advance a Game of Life board by one generation with a custom Metal kernel.

    One GPU thread is launched per cell (the launch grid is
    ``(a.shape[0], a.shape[1], 1)``). Each thread sums its eight neighbours,
    wrapping around at the board edges, and writes the cell's next state.

    Args:
        a: Board to step. The kernel indexes it as ``grid[i * m + j]`` and
            treats non-zero entries as live cells.
            NOTE(review): assumes ``a`` is 2-D -- confirm with callers.

    Returns:
        A new array with the same shape and dtype as ``a`` holding the next
        generation.
    """
    source = """
        uint i = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint n = threads_per_grid.x;
        uint m = threads_per_grid.y;

        // Toroidal neighbours: the last valid index is n - 1 (resp. m - 1),
        // so stepping back from 0 must land there, not on n (resp. m),
        // which would read past the end of the board.
        uint down = (i == 0) ? (n - 1) : (i - 1);
        uint up = (i + 1) == n ? 0 : (i + 1);
        uint left = (j == 0) ? (m - 1) : (j - 1);
        uint right = (j + 1) == m ? 0 : (j + 1);

        size_t idx = i * m + j;
        int count = grid[up * m + right] + grid[up * m + j]
          + grid[i * m + right] + grid[up * m + left] + grid[down * m + left]
          + grid[down * m + j] + grid[i * m + left] + grid[down * m + right];

        // Survive with two neighbours; be born or survive with three.
        if ((grid[idx] && count == 2) || count == 3) {
          out[idx] = true;
        } else {
          out[idx] = false;
        }
    """
    kernel = mx.fast.metal_kernel(
        name="conway",
        input_names=["grid"],
        output_names=["out"],
        source=source,
    )
    return kernel(
        inputs=[a],
        grid=(a.shape[0], a.shape[1], 1),
        threadgroup=(2, 512, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )[0]

def run_game(grid, fps=30):
    """Run an interactive, real-time Game of Life viewer in a pygame window.

    Controls:
        * ``+`` / ``=`` / mouse wheel up: zoom in.
        * ``-`` / mouse wheel down: zoom out, never below the zoom at which
          the whole board fits in the window.
        * Left-button drag: pan the view (applied when the board is larger
          than the window).
        * Closing the window ends the loop.

    Args:
        grid: Initial board; unpacked as ``(height, width)`` and advanced
            with ``conway`` once per displayed frame.
        fps: Frame-rate cap passed to ``pygame.time.Clock.tick``.
    """
    pygame.init()
    try:
        pygame.display.set_caption("Apple MLX Conway's Game of Life")

        window_size = (800, 800)
        screen = pygame.display.set_mode(window_size)
        clock = pygame.time.Clock()

        grid_height, grid_width = grid.shape

        # Smallest zoom allowed: the whole board fits inside the window.
        initial_zoom = min(window_size[0] / grid_width, window_size[1] / grid_height)
        zoom = initial_zoom

        dragging = False
        drag_start = None
        # Pan offset in screen pixels, relative to the centred view.
        view_offset = [0, 0]

        running = True

        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False

                elif event.type == pygame.KEYDOWN:
                    if event.key in (pygame.K_PLUS, pygame.K_EQUALS):
                        zoom *= 1.1
                    elif event.key == pygame.K_MINUS:
                        zoom = max(initial_zoom, zoom / 1.1)

                elif event.type == pygame.MOUSEWHEEL:
                    if event.y > 0:
                        zoom *= 1.1
                    elif event.y < 0:
                        zoom = max(initial_zoom, zoom / 1.1)

                elif event.type == pygame.MOUSEBUTTONDOWN:
                    if event.button == 1:  # Left mouse button
                        dragging = True
                        drag_start = event.pos
                elif event.type == pygame.MOUSEBUTTONUP:
                    if event.button == 1:
                        dragging = False
                elif event.type == pygame.MOUSEMOTION:
                    if dragging:
                        # Move the view opposite to the mouse so the board
                        # follows the cursor.
                        dx = event.pos[0] - drag_start[0]
                        dy = event.pos[1] - drag_start[1]
                        view_offset[0] -= dx
                        view_offset[1] -= dy
                        drag_start = event.pos

            # Invert and scale to 0/255, then copy the frame into NumPy.
            frame = (~grid).astype(mx.uint8) * 255
            mx.eval(frame)
            frame_np = np.array(frame)

            # Replicate the single channel into RGB.
            frame_rgb = np.stack([frame_np] * 3, axis=-1)

            # Swap the first two axes so the surface is indexed (x, y).
            surface = pygame.surfarray.make_surface(frame_rgb.swapaxes(0, 1))

            scaled_width = int(grid_width * zoom)
            scaled_height = int(grid_height * zoom)
            scaled_surface = pygame.transform.scale(surface, (scaled_width, scaled_height))

            if scaled_width < window_size[0] or scaled_height < window_size[1]:
                # Board smaller than the window: centre it on black.
                screen.fill((0, 0, 0))
                offset_x = (window_size[0] - scaled_width) // 2
                offset_y = (window_size[1] - scaled_height) // 2
                screen.blit(scaled_surface, (offset_x, offset_y))
            else:
                x = (scaled_width - window_size[0]) // 2 + view_offset[0]
                y = (scaled_height - window_size[1]) // 2 + view_offset[1]

                # Keep the visible rectangle inside the scaled board.
                x = max(0, min(x, scaled_width - window_size[0]))
                y = max(0, min(y, scaled_height - window_size[1]))

                # Write the clamped position back so panning does not
                # accumulate past the edge.
                view_offset[0] = x - (scaled_width - window_size[0]) // 2
                view_offset[1] = y - (scaled_height - window_size[1]) // 2

                cropped = scaled_surface.subsurface((x, y, window_size[0], window_size[1]))
                screen.blit(cropped, (0, 0))

            pygame.display.flip()

            grid = conway(grid)

            clock.tick(fps)
    finally:
        # Always shut pygame down, even if the loop raises.
        pygame.quit()


if __name__ == "__main__":
    # Random square board with roughly 30% of the cells alive, shown at 60 fps.
    side = 2048
    run_game(mx.random.bernoulli(p=0.3, shape=(side, side)), fps=60)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment