Skip to content

Instantly share code, notes, and snippets.

View yberreby's full-sized avatar

Yohaï-Eliel Berreby yberreby

View GitHub Profile
@yberreby
yberreby / bench_jax.py
Last active April 20, 2025 01:16
Quick JAX vs. Triton comparison on a toy kernel. Outputs are from runs on an RTX 4060 Mobile.
import functools, time, jax, jax.numpy as jnp
# Set JAX's default matmul precision to TensorFloat-32 ("tensorfloat32").
jax.config.update("jax_default_matmul_precision", "tensorfloat32")

# sqrt(2 / pi): the scale applied inside tanh in the GELU approximation.
SQRT2_OVER_PI = 0.7978845608028654


# ----------------------------------------------------------------------
def gelu_fast(x):
    """Return the tanh approximation of GELU, evaluated elementwise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    with ``jnp`` ops. The multiplication order matches a single
    left-to-right expression, so rounding is unchanged.
    """
    # ((0.044715 * x) * x) * x -- same association as the inline form.
    cubic_term = 0.044715 * x * x * x
    tanh_arg = SQRT2_OVER_PI * (x + cubic_term)
    half_x = 0.5 * x
    return half_x * (1. + jnp.tanh(tanh_arg))
@yberreby
yberreby / 0_geom_opt_cmp.py
Last active April 20, 2025 01:16
Quick comparison of a few optimizers on the 2-simplex: GD, Adam, Mirror Descent, Adam in mirror space, L-BFGS in mirror space
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "jax==0.5.0",
# "jaxopt>=0.8.5",
# "optax>=0.2.4",
# "matplotlib>=3.10.1",
# "pyqt6>=6.9.0", # for matplotlib gui
# ]
@yberreby
yberreby / cahn_hilliard_literate.py
Last active April 19, 2025 00:06
Cahn-Hilliard: run this to see visually interesting nonlinear PDE behavior in an animated plot.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "jax==0.5.0",
# "matplotlib>=3.10.1",
# "pyqt6>=6.9.0", # for matplotlib gui
# ]
# ///
@yberreby
yberreby / eccentricity_dependent_perlin_noise_jax.py
Last active April 17, 2025 23:19
Fast Perlin noise in JAX with eccentricity-dependent feature scaling. GPU-ready.
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pyqt6", # For matplotlib backend
# "numpy",
# "matplotlib",
# "jax[cuda12]==0.5.2", # Change for CPU.
# ]
# ///
@yberreby
yberreby / jax_mlp.py
Created December 13, 2023 21:06
Minimal MLP in JAX - excerpt from the "Working with Pytrees" section of the JAX manual
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
@ TOTAL=10000,DT=.1,xlo=0,xhi=1,ylo=-100,yhi=100
@ NPLOT=1,XP1=w,YP1=V
@ MAXSTOR=10000000
@ BOUNDS=100000
@ dsmin=1e-5,dsmax=0.5,parmin=0,parmax=3,autoxmin=0,autoxmax=1,Ntst=150,Nmax=2000,Npr=500,Ds=0.02,EPSL=1e-7,EPSU=1e-7,EPSS=1e-7,*Y-axis=V
@ autoymax=100,autoymin=-100
@yberreby
yberreby / results
Created September 4, 2016 13:34
Rust Benchmark - printing to stdout
print_macro: 105 ns/iter (+/- 20)
print_macro_locked_stdoutbench: 87 ns/iter (+/- 22)
direct_locked_stdout: 17 ns/iter (+/- 2)
direct_unlocked_stdout: 51 ns/iter (+/- 9)
#![feature(test)]
extern crate test;
use test::Bencher;
/// Problem size for the benchmarks. For example, `bench_even_imperative`
/// pre-sizes its output `Vec` to `LARGE_NUMBER / 2 + 1` elements.
const LARGE_NUMBER: i32 = 1_000_000;
#[bench]
fn bench_even_imperative(b: &mut Bencher) {
b.iter(|| {
let mut list = Vec::with_capacity(LARGE_NUMBER as usize / 2 + 1);
@yberreby
yberreby / mod.rs
Created June 28, 2015 14:03
src/libstd/rt/unwind/mod.rs - panic handler PoC
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#![feature(rt)]
#![feature(unmarked_api)]
use std::thread;
use std::any::Any;
use std::rt::unwind::set_panic_handler;
fn main() {
// Use the default handler
panic!("Something's wrong"); // Prints "thread '<main>' panicked at 'Something's wrong', /Users/yohai/code/panic_handlers_test.rs:10"