Created
July 1, 2024 20:30
-
-
Save ekzhang/0db8693ea76dbdf80f693a3470d4ede9 to your computer and use it in GitHub Desktop.
An incomplete interview question for performance-optimizing Hashlife in Rust
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{collections::HashMap, sync::Arc}; | |
use md5::{Digest, Md5}; | |
#[derive(Clone)] | |
pub enum Node { | |
Subtree(Arc<Subtree>), | |
Leaf(Leaf), | |
} | |
pub struct Subtree { | |
pub nw: Node, | |
pub ne: Node, | |
pub sw: Node, | |
pub se: Node, | |
pub hash: u128, | |
pub count: u64, | |
/// Stores the result of the central half-sized region after 2^k generations. | |
/// This starts empty, then it is filled in on-demand and cached. | |
pub result: HashMap<u8, Node>, | |
} | |
/// The smallest quadtree node: a 4x4 block of cells.
#[derive(Clone, Copy)]
pub struct Leaf {
    /// The sixteen cells packed into one integer.
    ///
    /// The cell at column `x`, row `y` (both in `0..4`) is stored in bit
    /// `y * 4 + x`, counting from the least significant bit.
    pub value: u16,
}
impl Node { | |
pub fn hash(&self) -> u128 { | |
match self { | |
Self::Subtree(subtree) => subtree.hash, | |
Self::Leaf(leaf) => 0xdeadbeefdeadbeef1234567812345678 ^ u128::from(leaf.value), | |
} | |
} | |
pub fn count(&self) -> u64 { | |
match self { | |
Self::Subtree(subtree) => subtree.count, | |
Self::Leaf(leaf) => leaf.value.count_ones() as u64, | |
} | |
} | |
/// Return the log2() of the width and height of the region. | |
pub fn order(&self) -> u8 { | |
match self { | |
Self::Subtree(subtree) => 1 + subtree.nw.order(), | |
Self::Leaf(_) => 2, // 4x4 base case | |
} | |
} | |
pub fn get_cell(&self, x: u64, y: u64) -> bool { | |
match self { | |
Self::Subtree(subtree) => { | |
let mid = 1 << subtree.nw.order(); | |
match (x >= mid, y >= mid) { | |
(false, false) => subtree.nw.get_cell(x, y), | |
(true, false) => subtree.ne.get_cell(x - mid, y), | |
(false, true) => subtree.sw.get_cell(x, y - mid), | |
(true, true) => subtree.se.get_cell(x - mid, y - mid), | |
} | |
} | |
Self::Leaf(leaf) => (leaf.value >> (y * 4 + x)) & 1 != 0, | |
} | |
} | |
pub fn as_subtree(&self) -> &Subtree { | |
match self { | |
Self::Subtree(subtree) => subtree, | |
_ => panic!("expected a subtree"), | |
} | |
} | |
pub fn as_8x8(&self) -> u64 { | |
debug_assert!(self.order() == 3); | |
match self.as_subtree() { | |
Subtree { | |
nw: Node::Leaf(nw), | |
ne: Node::Leaf(ne), | |
sw: Node::Leaf(sw), | |
se: Node::Leaf(se), | |
.. | |
} => { | |
let split = |x: u16| -> u64 { | |
u64::from(x & 0xf000) | |
| u64::from(x & 0x0f00) << 4 | |
| u64::from(x & 0x00f0) << 8 | |
| u64::from(x & 0x000f) << 12 | |
}; | |
split(nw.value) | |
| split(ne.value) << 4 | |
| split(sw.value) << 32 | |
| split(se.value) << 36 | |
} | |
_ => unreachable!(), | |
} | |
} | |
} | |
/// Given an 8x8 region, produce the next generation of the 6x6 in the center.
///
/// `value` holds the cell at column `x`, row `y` in bit `y * 8 + x`. The
/// result is shifted to the origin: bit `row * 8 + col` (with `row`, `col` in
/// `0..6`) is the next state of input cell `(col + 1, row + 1)`; all other
/// bits are zero. Because the output uses the same row stride as the input it
/// can be fed straight back into `step_8x8` to advance a second generation.
fn step_8x8(value: u64) -> u64 {
    // The eight neighbors of the cell at bit 9 (column 1, row 1): three bits
    // in the row above, two beside it and three in the row below.
    const NEIGHBORS: u64 = 0x0007_0507;

    let mut result = 0_u64;
    for row in 0..6_u32 {
        for col in 0..6_u32 {
            // Shifting by `offset` moves the 3x3 window around input cell
            // (col + 1, row + 1) onto bits 0..=18.
            let offset = row * 8 + col;
            let window = value >> offset;
            let neighbors = (window & NEIGHBORS).count_ones();
            let cell = (window >> 9) & 1 != 0;
            if neighbors == 3 || (neighbors == 2 && cell) {
                result |= 1 << offset;
            }
        }
    }
    result
}
/// Extract the north-west 4x4 quadrant of an 8x8 region as leaf bits.
///
/// `value` holds the cell at column `x`, row `y` in bit `y * 8 + x`; the
/// returned integer holds the cell at column `x`, row `y` (both in `0..4`) in
/// bit `y * 4 + x`.
fn extract_8x8_nw(value: u64) -> u16 {
    // Row `r` of the quadrant is the low nibble of byte `r`; move it down by
    // `4 * r` bits so the four nibbles become adjacent.
    let packed = (value & 0x000f)
        | (value >> 4) & 0x00f0
        | (value >> 8) & 0x0f00
        | (value >> 12) & 0xf000;
    // Only the low 16 bits can be set after masking, so the cast is lossless.
    packed as u16
}
/// Produce the hash for a subtree node from the hashes of its children. | |
fn combine_hashes(nw: u128, ne: u128, sw: u128, se: u128) -> u128 { | |
let mut digest = Md5::new(); | |
digest.update(nw.to_le_bytes()); | |
digest.update(ne.to_le_bytes()); | |
digest.update(sw.to_le_bytes()); | |
digest.update(se.to_le_bytes()); | |
let bytes: [u8; 16] = digest.finalize().into(); | |
u128::from_le_bytes(bytes) | |
} | |
#[derive(Clone)] | |
pub struct Board { | |
/// Coordinates of the top-left of the root node's region. | |
offset: (i64, i64), | |
/// The cells of the grid as a quadtree. | |
root: Node, | |
} | |
impl Board { | |
pub fn get_cell(&self, x: i64, y: i64) -> bool { | |
let dims = 1_i64 << self.root.order(); | |
let x0 = x - self.offset.0; | |
let y0 = y - self.offset.1; | |
if x0 < 0 || x0 >= dims || y0 < 0 || y0 >= dims { | |
return false; // all cells outside of the root are dead | |
} | |
self.root.get_cell(x0 as u64, y0 as u64) | |
} | |
} | |
pub struct Engine { | |
/// A cache with previously-computed quadtree nodes. | |
cache: HashMap<u128, Node>, | |
} | |
impl Engine { | |
/// Create a new simulation engine with an empty cache. | |
pub fn new() -> Self { | |
let cache = HashMap::new(); | |
Self { cache } | |
} | |
/// Simulate 2^k steps of a pattern. | |
pub fn step(&self, board: &Board, k: u8) -> Board { | |
let mut board = board.clone(); | |
let count = board.root.count(); | |
while board.root.order() < k + 2 { | |
board = self.expand(board); | |
} | |
if self.central(&board.root).count() < count { | |
board = self.expand(board); | |
} | |
let order = board.root.order(); | |
let root = self.node_step(&board.root, k); | |
Board { | |
offset: ( | |
board.offset.0 + i64::from(1 << (order - 2)), | |
board.offset.1 + i64::from(1 << (order - 2)), | |
), | |
root, | |
} | |
} | |
pub fn parse_rle(&self, pattern: &str) -> Board { | |
} | |
fn expand(&self, board: Board) -> Board { | |
let sub = board.root.as_subtree(); | |
let k = sub.nw.order(); | |
let zeros = self.zeros(k); | |
let offset = ( | |
board.offset.0 - i64::from(1 << k), | |
board.offset.1 - i64::from(1 << k), | |
); | |
let root = self.subtree( | |
&self.subtree(&zeros, &zeros, &zeros, &sub.nw), | |
&self.subtree(&zeros, &zeros, &sub.ne, &zeros), | |
&self.subtree(&zeros, &sub.sw, &zeros, &zeros), | |
&self.subtree(&sub.se, &zeros, &zeros, &zeros), | |
); | |
Board { offset, root } | |
} | |
/// Create a subtree out of four child nodes. | |
fn subtree(&self, nw: &Node, ne: &Node, sw: &Node, se: &Node) -> Node { | |
let hash = combine_hashes(nw.hash(), ne.hash(), sw.hash(), se.hash()); | |
if let Some(node) = self.cache.get(&hash) { | |
node.clone() | |
} else { | |
Node::Subtree(Arc::new(Subtree { | |
nw: nw.clone(), | |
ne: ne.clone(), | |
sw: sw.clone(), | |
se: se.clone(), | |
hash, | |
count: nw.count() + ne.count() + sw.count() + se.count(), | |
result: HashMap::new(), | |
})) | |
} | |
} | |
fn zeros(&self, order: u8) -> Node { | |
if order == 2 { | |
Node::Leaf(Leaf { value: 0 }) | |
} else { | |
let sub = self.zeros(order - 1); | |
self.subtree(&sub, &sub, &sub, &sub) | |
} | |
} | |
fn hstack(&self, w: &Node, e: &Node) -> Node { | |
let w = w.as_subtree(); | |
let e = e.as_subtree(); | |
self.subtree(&w.ne, &e.nw, &w.se, &e.sw) | |
} | |
fn vstack(&self, n: &Node, s: &Node) -> Node { | |
let n = n.as_subtree(); | |
let s = s.as_subtree(); | |
self.subtree(&n.sw, &n.se, &s.nw, &s.ne) | |
} | |
fn central(&self, node: &Node) -> Node { | |
let node = node.as_subtree(); | |
self.subtree( | |
&node.nw.as_subtree().se, | |
&node.ne.as_subtree().sw, | |
&node.sw.as_subtree().ne, | |
&node.se.as_subtree().nw, | |
) | |
} | |
/// Step through 2^k generations of a node in the quadtree, returning the central half. | |
fn node_step(&self, node: &Node, k: u8) -> Node { | |
// The node must be at least 2^(k+1) x 2^(k+1) in size to simulate 2^k generations. | |
debug_assert!(node.order() >= k + 2); | |
// Base case: 8x8 region. | |
if node.order() == 3 { | |
let value = node.as_8x8(); | |
return Node::Leaf(Leaf { | |
value: if k == 0 { | |
extract_8x8_nw(step_8x8(value) >> 9) | |
} else { | |
debug_assert!(k == 1); | |
extract_8x8_nw(step_8x8(step_8x8(value))) | |
}, | |
}); | |
} | |
let sub = node.as_subtree(); | |
if node.order() == k + 2 { | |
// Recursive case 1: the region is exactly 2^(k+1) x 2^(k+1) in size. | |
let s00 = self.node_step(&sub.nw, k - 1); | |
let s10 = self.node_step(&self.hstack(&sub.nw, &sub.ne), k - 1); | |
let s20 = self.node_step(&sub.ne, k - 1); | |
let s01 = self.node_step(&self.vstack(&sub.nw, &sub.sw), k - 1); | |
let s11 = self.node_step(&self.central(node), k - 1); | |
let s21 = self.node_step(&self.vstack(&sub.ne, &sub.se), k - 1); | |
let s02 = self.node_step(&sub.sw, k - 1); | |
let s12 = self.node_step(&self.hstack(&sub.sw, &sub.se), k - 1); | |
let s22 = self.node_step(&sub.se, k - 1); | |
self.subtree( | |
&self.node_step(&self.subtree(&s00, &s10, &s01, &s11), k - 1), | |
&self.node_step(&self.subtree(&s10, &s20, &s11, &s21), k - 1), | |
&self.node_step(&self.subtree(&s01, &s11, &s02, &s12), k - 1), | |
&self.node_step(&self.subtree(&s11, &s21, &s12, &s22), k - 1), | |
) | |
} else { | |
// Recursive case 2: the region is larger than 2^(k+1) x 2^(k+1). | |
let s00 = self.central(&sub.nw); | |
let s10 = self.central(&self.hstack(&sub.nw, &sub.ne)); | |
let s20 = self.central(&sub.ne); | |
let s01 = self.central(&self.vstack(&sub.nw, &sub.sw)); | |
let s11 = self.central(&self.central(node)); | |
let s21 = self.central(&self.vstack(&sub.ne, &sub.se)); | |
let s02 = self.central(&sub.sw); | |
let s12 = self.central(&self.hstack(&sub.sw, &sub.se)); | |
let s22 = self.central(&sub.se); | |
self.subtree( | |
&self.node_step(&self.subtree(&s00, &s10, &s01, &s11), k), | |
&self.node_step(&self.subtree(&s10, &s20, &s11, &s21), k), | |
&self.node_step(&self.subtree(&s01, &s11, &s02, &s12), k), | |
&self.node_step(&self.subtree(&s11, &s21, &s12, &s22), k), | |
) | |
} | |
} | |
} | |
impl Default for Engine { | |
fn default() -> Self { | |
Self::new() | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment