Skip to content

Instantly share code, notes, and snippets.

@hathibelagal-dev
Created November 10, 2025 16:19
Show Gist options
  • Select an option

  • Save hathibelagal-dev/780c23514c62c150657d85ed2b16dfe6 to your computer and use it in GitHub Desktop.

Select an option

Save hathibelagal-dev/780c23514c62c150657d85ed2b16dfe6 to your computer and use it in GitHub Desktop.
Using Extropic's THRML to solve the n-queens puzzle
import jax
import jax.numpy as jnp
import random
import sys
import numpy as np
from thrml.block_management import Block
from thrml.block_sampling import BlockGibbsSpec, sample_states, SamplingSchedule
from thrml.pgm import CategoricalNode
from thrml.factor import AbstractFactor, FactorSamplingProgram
from thrml.interaction import InteractionGroup
from thrml.conditional_samplers import SoftmaxConditional
from jaxtyping import Array, Key, PyTree, Shaped
# Type alias used in the sampler's `states` annotation: a PyTree whose array
# leaves have a leading "nodes" axis followed by a variable number of extra
# per-node state dimensions (jaxtyping shape string "nodes ?*state").
_State = PyTree[Shaped[Array, "nodes ?*state"], "State"]
# 1. Problem Setup
# Board size: the script searches for N non-attacking queens on an N x N board,
# with one categorical node (one queen) per row.
N: int = 8
# 2. Custom Sampler for N-Queens Constraints
class NQueensSampler(SoftmaxConditional):
    """Conditional sampler that moves one row's queen to a non-attacked column.

    Each free block holds a single row node whose state is the column of that
    row's queen. Given the columns of all other rows, every candidate column
    that shares a column or a diagonal with another queen gets a logit of
    ``-1e9``, and every safe column gets ``0.0``. Drawing from these logits
    with ``jax.random.categorical`` therefore picks (almost exactly) uniformly
    among the safe columns. If every column is attacked, all logits are equal
    and the draw is uniform over all ``n_choices`` columns.
    """

    # Number of columns a queen can be placed in (the board width).
    n_choices: int

    def compute_parameters(
        self,
        key: Key,
        interactions: list[PyTree],
        active_flags: list[Array],
        states: list[list[_State]],
        sampler_state: None,
        output_sd: PyTree[jax.ShapeDtypeStruct],
    ) -> PyTree:
        """Return ``(logits, sampler_state)`` for the head row.

        ``logits`` has shape ``(batch_size, n_choices)``. ``sampler_state``
        is passed through unchanged. ``key``, ``active_flags`` and
        ``output_sd`` are not used.
        """
        # We expect exactly one interaction group (built by AllToAllFactor).
        interaction_data = interactions[0]

        # states[0] holds the N-1 tail states. Each is assumed to be shaped
        # (batch_size, 1) because the tail blocks have size 1 -- TODO confirm
        # against THRML's state layout.
        other_cols = jnp.concatenate([s[0] for s in states[0]], axis=1)  # (batch, N-1)

        # Drop the singleton axes added by the sampling program. reshape (rather
        # than squeeze) always yields a 1-D vector of other rows, independent of
        # n_choices, and fails loudly if the payload is not the expected size.
        other_rows = jnp.reshape(interaction_data["other_rows"], (-1,))  # (N-1,)
        current_row = jnp.reshape(interaction_data["current_row"], ())  # scalar

        # Node states are uint8. Cast to a signed type so that column differences
        # go negative instead of wrapping around before abs().
        occupied = other_cols.astype(jnp.int32)[:, None, :]  # (batch, 1, N-1)
        candidates = jnp.arange(self.n_choices)[None, :, None]  # (1, n_choices, 1)
        row_gap = jnp.abs(current_row - other_rows)  # (N-1,)

        same_col = candidates == occupied
        same_diag = jnp.abs(candidates - occupied) == row_gap
        # A candidate column is attacked if ANY other queen hits it.
        attacked = jnp.any(same_col | same_diag, axis=-1)  # (batch, n_choices)

        logits = jnp.where(attacked, -1e9, 0.0)
        return logits, sampler_state

    def sample_given_parameters(self, key, parameters, sampler_state, output_sd):
        """Draw one column per chain from ``parameters`` (the logits).

        Returns ``(samples, sampler_state)``. ``samples`` is cast to
        ``output_sd.dtype`` and shaped ``(batch_size, 1)`` to match the state
        of a size-1 block.
        """
        # parameters are the logits, shape (batch_size, n_choices)
        samples = jax.random.categorical(key, parameters, axis=-1).astype(output_sd.dtype)
        # samples has shape (batch_size,); add the block-size axis.
        reshaped_samples = samples[:, None]
        return reshaped_samples, sampler_state
# 3. Custom Factor to create the all-to-all interactions
class AllToAllFactor(AbstractFactor):
    """Factor that couples every node in its first node group to every other node.

    For each node it emits one ``InteractionGroup``. The head of the group is
    that node, and the tails are all the other nodes, each in its own size-1
    block. The interaction payload carries the row indices that
    ``NQueensSampler`` needs:

    * ``"current_row"``: the head's index.
    * ``"other_rows"``: the tails' indices, in the same order as the tail blocks.
    """

    def to_interaction_groups(self) -> list[InteractionGroup]:
        """Build one interaction group per node in ``self.node_groups[0]``."""
        nodes = self.node_groups[0].nodes
        n_nodes = len(nodes)
        interactions = []
        for i, head_node in enumerate(nodes):
            # Compute the "other" indices once, so that the tail blocks and the
            # row labels in the payload are guaranteed to line up.
            other_idx = [j for j in range(n_nodes) if j != i]
            # dtype=int keeps an integer dtype even when other_idx is empty
            # (single-node group). Without it, jnp.array([]) defaults to float.
            tail_rows = jnp.asarray(other_idx, dtype=int)
            interaction = InteractionGroup(
                interaction={
                    "other_rows": tail_rows[None, ...],  # Shape: (1, N-1)
                    "current_row": jnp.array([i]),  # Shape: (1,)
                },
                head_nodes=Block([head_node]),
                # N-1 tail blocks, each of size 1
                tail_nodes=[Block([nodes[j]]) for j in other_idx],
            )
            interactions.append(interaction)
        return interactions
# 4. Setup Sampling Program
# One categorical node per board row; a node's value is its queen's column.
row_nodes = [CategoricalNode() for _row in range(N)]

# A single factor links every row node with every other row node.
factor = AllToAllFactor([Block(row_nodes)])

# Gibbs schedule: each row is its own free block, so rows are resampled
# one at a time. Nothing is clamped.
free_blocks = [Block([row_node]) for row_node in row_nodes]
clamped_blocks = []

# Every categorical node stores a scalar uint8 state.
node_shape_dtypes = {CategoricalNode: jax.ShapeDtypeStruct((), jnp.uint8)}
spec = BlockGibbsSpec(free_blocks, clamped_blocks, node_shape_dtypes)

# The same sampler instance is shared by all free blocks.
sampler = NQueensSampler(n_choices=N)
samplers = [sampler for _block in free_blocks]

program = FactorSamplingProgram(
    gibbs_spec=spec,
    samplers=samplers,
    factors=[factor],
    other_interaction_groups=[],
)
# 5. Run Sampling
def _queens_attack(positions):
    """Return True if any two queens in ``positions`` attack each other.

    ``positions`` is a list of ``(row, col)`` pairs of plain Python ints.
    Two queens attack when they share a column or a diagonal. Rows are not
    compared because the caller places exactly one queen per row.
    """
    return any(
        c1 == c2 or abs(r1 - r2) == abs(c1 - c2)
        for idx, (r1, c1) in enumerate(positions)
        for r2, c2 in positions[idx + 1:]
    )


# The seed is printed so that a successful run can be reproduced.
_seed = random.randint(102, sys.maxsize)
print(f"Using seed: {_seed}")
key = jax.random.key(_seed)
batch_size = 20  # Run a few independent chains to increase chances of finding a solution
schedule = SamplingSchedule(
    n_warmup=200,  # With this model, it should converge very quickly
    n_samples=1,
    steps_per_sample=1,
)

# Client-side batching loop: each iteration runs one independent chain.
for i in range(batch_size):
    print(f"Running chain {i + 1}/{batch_size}...")
    key, subkey1, subkey2 = jax.random.split(key, 3)
    # Use a distinct key per row. Reusing subkey1 for every row would give all
    # rows the same starting column (every queen in one column).
    init_keys = jax.random.split(subkey1, N)
    single_init_state = [
        jax.random.randint(row_key, (1, 1), 0, N, dtype=jnp.uint8)
        for row_key in init_keys
    ]

    # Run sampling for the single chain.
    final_states_per_block = sample_states(
        subkey2, program, schedule, single_init_state, [], spec.free_blocks
    )

    # Reconstruct the board for this chain.
    board = np.zeros((N, N), dtype=int)
    positions = []
    for r in range(N):
        # Single chain, so take the first entry of each leading axis.
        # int(): the state dtype is uint8, and subtracting two uint8 values in
        # the diagonal check would wrap around instead of going negative.
        c = int(final_states_per_block[r][0, 0, 0])
        board[r, c] = 1
        positions.append((r, c))

    if not _queens_attack(positions):
        print("Found a valid solution!")
        print(board)
        sys.exit()

print(f"No valid solution found in {batch_size} attempts.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment