Created November 10, 2025 16:19
-
-
Save hathibelagal-dev/780c23514c62c150657d85ed2b16dfe6 to your computer and use it in GitHub Desktop.
Using Extropic's THRML to solve the n-queens puzzle
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| import jax.numpy as jnp | |
| import random | |
| import sys | |
| import numpy as np | |
| from thrml.block_management import Block | |
| from thrml.block_sampling import BlockGibbsSpec, sample_states, SamplingSchedule | |
| from thrml.pgm import CategoricalNode | |
| from thrml.factor import AbstractFactor, FactorSamplingProgram | |
| from thrml.interaction import InteractionGroup | |
| from thrml.conditional_samplers import SoftmaxConditional | |
| from jaxtyping import Array, Key, PyTree, Shaped | |
# jaxtyping alias for a block state: a PyTree of arrays whose first axis is
# named "nodes".
_State = PyTree[Shaped[Array, "nodes ?*state"], "State"]

# 1. Problem Setup
# Board size: N queens placed on an N x N board (one node per row, N columns).
N: int = 8
# 2. Custom Sampler for N-Queens Constraints
class NQueensSampler(SoftmaxConditional):
    """Conditional sampler that moves one row's queen to a non-attacked column.

    Each row node's value is the column of that row's queen. Given the columns
    currently held by all other rows, ``compute_parameters`` builds logits over
    the ``n_choices`` columns: ``0.0`` for a column that no other queen attacks,
    and ``-1e9`` for a column sharing a column or diagonal with another queen.
    Sampling from those logits therefore picks uniformly among safe columns.
    If every column is attacked, all logits are equal, so the choice is uniform
    over all columns.
    """

    # Number of columns a queen may occupy (the board width).
    n_choices: int

    def compute_parameters(
        self,
        key: Key,
        interactions: list[PyTree],
        active_flags: list[Array],
        states: list[list[_State]],
        sampler_state: None,
        output_sd: PyTree[jax.ShapeDtypeStruct],
    ) -> PyTree:
        """Return ``(logits, sampler_state)`` for the row being updated.

        Args:
            key: PRNG key (unused; the logits are deterministic).
            interactions: one payload per interaction group. Only the first is
                read: a dict with ``"other_rows"`` (indices of the other rows)
                and ``"current_row"`` (index of the row being updated).
            active_flags: unused.
            states: ``states[0]`` holds one state per tail block, i.e. the
                current column of every other row.
            sampler_state: passed through unchanged.
            output_sd: unused here.

        Returns:
            A ``(logits, sampler_state)`` pair. ``logits`` has shape
            ``(batch_size, n_choices)``.
        """
        # We expect one interaction group.
        interaction_data = interactions[0]

        # states[0] is a list of N-1 tail states. Each is a PyTree whose first
        # leaf is an array of shape (batch_size, 1), because the blocks are size 1.
        other_cols_list = [s[0] for s in states[0]]
        # Cast to a signed dtype: node states are stored unsigned (uint8), and an
        # unsigned c1 - c2 would wrap around instead of going negative.
        other_cols = jnp.concatenate(other_cols_list, axis=1).astype(jnp.int32)  # (batch_size, N-1)

        # Data from the interaction, with the leading singleton dimensions
        # added by the sampling program squeezed away.
        other_rows = interaction_data["other_rows"].squeeze()  # was (1, 1, N-1)
        current_row = interaction_data["current_row"].squeeze()  # was (1, 1)

        batch_size = other_cols.shape[0]
        # Broadcast the batch-independent row indices to the batch. The tail count
        # comes from the gathered states rather than being assumed to be n_choices - 1.
        other_rows_b = jnp.broadcast_to(other_rows, other_cols.shape)
        current_row_b = jnp.broadcast_to(current_row, (batch_size,))

        def calculate_logits_for_board(current_row_single, other_rows_single, other_cols_single):
            """Logits over columns for one board: 0.0 if safe, -1e9 if attacked.

            current_row_single: scalar; other_rows_single / other_cols_single: [N-1].
            """

            def is_attacked_col(c1):
                # c1 is the column we are considering for the current row.
                def attacks(r2, c2):
                    same_col = c1 == c2
                    same_diag = jnp.abs(current_row_single - r2) == jnp.abs(c1 - c2)
                    return same_col | same_diag

                return jnp.any(jax.vmap(attacks)(other_rows_single, other_cols_single))

            attacked = jax.vmap(is_attacked_col)(jnp.arange(self.n_choices))
            return jnp.where(attacked, -1e9, 0.0)

        # Vmap over the batch dimension
        logits = jax.vmap(calculate_logits_for_board)(current_row_b, other_rows_b, other_cols)
        return logits, sampler_state

    def sample_given_parameters(self, key, parameters, sampler_state, output_sd):
        """Draw one column per chain from ``parameters`` (logits).

        Samples are cast to ``output_sd.dtype`` and given a trailing axis of
        size 1 to match the state of a single-node block. Returns
        ``(samples, sampler_state)``.
        """
        # parameters are the logits, shape (batch_size, n_choices)
        samples = jax.random.categorical(key, parameters, axis=-1).astype(output_sd.dtype)
        # samples has shape (batch_size,); reshape to (batch_size, 1).
        reshaped_samples = samples[:, None]
        return reshaped_samples, sampler_state
# 3. Custom Factor to create the all-to-all interactions
class AllToAllFactor(AbstractFactor):
    """Factor that couples every node of its first node group to all the others.

    ``to_interaction_groups`` emits one InteractionGroup per node. In each group:

    - the node is the head;
    - every other node is a tail, each in its own single-node Block;
    - the payload records the head's index (``"current_row"``) and the tail
      indices (``"other_rows"``), so a sampler can tell which row each tail
      state belongs to.
    """

    def to_interaction_groups(self) -> list[InteractionGroup]:
        """Return one InteractionGroup per node, with all other nodes as tails."""
        nodes = self.node_groups[0].nodes
        interactions = []
        for i, head_node in enumerate(nodes):
            other_indices = [j for j in range(len(nodes)) if j != i]
            # Create N-1 tail blocks, each of size 1
            tail_blocks = [Block([nodes[j]]) for j in other_indices]
            # Explicit integer dtype: for a single-node group the empty index
            # list would otherwise become a float array.
            tail_rows = jnp.array(other_indices, dtype=int)
            interactions.append(
                InteractionGroup(
                    interaction={
                        "other_rows": tail_rows[None, ...],  # Shape: (1, N-1)
                        "current_row": jnp.array([i]),  # Shape: (1,)
                    },
                    head_nodes=Block([head_node]),
                    tail_nodes=tail_blocks,
                )
            )
        return interactions
# 4. Setup Sampling Program
# One categorical node per board row; a node's value is the column of that
# row's queen.
row_nodes = [CategoricalNode() for _row in range(N)]

# A single all-to-all factor couples every row node with every other one.
factor = AllToAllFactor([Block(row_nodes)])

# Gibbs spec: each row is its own free block (one row per update), and
# nothing is clamped. Every categorical node holds a scalar uint8 column.
free_blocks = [Block([row_node]) for row_node in row_nodes]
clamped_blocks = []
node_shape_dtypes = {CategoricalNode: jax.ShapeDtypeStruct((), jnp.uint8)}
spec = BlockGibbsSpec(free_blocks, clamped_blocks, node_shape_dtypes)

# The same N-queens sampler instance is reused for every free block.
sampler = NQueensSampler(n_choices=N)
samplers = [sampler for _block in free_blocks]

program = FactorSamplingProgram(
    gibbs_spec=spec,
    samplers=samplers,
    factors=[factor],
    other_interaction_groups=[],
)
# 5. Run Sampling
# Draw a fresh seed from Python's `random` module on every run and echo it,
# so the run can be repeated with the same seed.
_seed = random.randint(102, sys.maxsize)
print(f"Using seed: {_seed}")
key = jax.random.key(_seed)

# Number of independent chains to try; several chains raise the chance that
# at least one ends on a valid placement.
batch_size = 20

# 200 warm-up steps, then a single recorded sample (one step per sample).
# The author expects this model to converge very quickly.
schedule = SamplingSchedule(n_warmup=200, n_samples=1, steps_per_sample=1)
def _has_attacking_pair(positions):
    """Return True if any two queens in ``positions`` attack each other.

    ``positions`` is a sequence of ``(row, col)`` pairs, one per queen. Two
    queens attack when they share a column or have equal row and column
    distance (a common diagonal). Rows are not compared, because the model
    places exactly one queen per row.

    Coordinates must be plain signed ints. With unsigned array scalars,
    ``c1 - c2`` wraps around when ``c1 < c2``, and those diagonal attacks
    would be missed.
    """
    return any(
        c1 == c2 or abs(r1 - r2) == abs(c1 - c2)
        for idx, (r1, c1) in enumerate(positions)
        for r2, c2 in positions[idx + 1:]
    )


# Client-side batching loop: each iteration runs one independent chain.
for i in range(batch_size):
    print(f"Running chain {i + 1}/{batch_size}...")
    # Create initial state for a single chain.
    key, subkey1, subkey2 = jax.random.split(key, 3)
    # One key per row: reusing subkey1 for every row would give every queen
    # the same starting column.
    row_keys = jax.random.split(subkey1, N)
    single_init_state = [
        jax.random.randint(row_key, (1, 1), 0, N, dtype=jnp.uint8) for row_key in row_keys
    ]
    # Run sampling for the single chain.
    final_states_per_block = sample_states(subkey2, program, schedule, single_init_state, [], spec.free_blocks)

    # Reconstruct the board for this chain.
    board = np.zeros((N, N), dtype=int)
    positions = []
    for r in range(N):
        # Result is for a single chain, so batch dimension is 0.
        # NOTE(review): [0, 0, 0] assumes a (n_samples, batch, block_size)
        # layout — confirm against sample_states.
        # int() turns the uint8 array scalar into a Python int, so the attack
        # check below uses signed arithmetic.
        c = int(final_states_per_block[r][0, 0, 0])
        board[r, c] = 1
        positions.append((r, c))

    # Verify the solution.
    if not _has_attacking_pair(positions):
        print("Found a valid solution!")
        print(board)
        sys.exit()

print(f"No valid solution found in {batch_size} attempts.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.