@jonnyli1125
Last active November 30, 2024 06:31
Self-play Reinforcement Learning (Q-Learning) for Tic Tac Toe in 100 lines of code
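In outline, the script below trains a single tabular Q-function by self-play: X is encoded as 1 and O as -1, each table entry is keyed by (flattened board, action index, player to move), moves are chosen epsilon-greedily, and the per-move reward is the +1 / -1 / 0 result from get_winner. Because the opponent moves next, the TD target is minimax-flavored: a move by X backs up the opponent's minimum next Q-value and a move by O backs up the maximum, so each step performs the update

    Q[get_key(board, action, player)] += alpha * (reward + gamma * best_next_q - Q[get_key(board, action, player)])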
import argparse
import random
from collections import defaultdict

import numpy as np
from tqdm import tqdm

BOARD_SIZE = 3

def get_winner(board: np.ndarray) -> int:
    """Return 1 if X has three in a row, -1 if O does, 0 otherwise (rows, columns, diagonals)."""
    for player in [1, -1]:
        for i in range(3):
            if all(board[i, :] == player) or all(board[:, i] == player):
                return player
        if player == board[0, 0] == board[1, 1] == board[2, 2]:
            return player
        if player == board[0, 2] == board[1, 1] == board[2, 0]:
            return player
    return 0

def possible_actions(board: np.ndarray) -> list[tuple[int, int]]:
    return [(int(x), int(y)) for x, y in zip(*(board == 0).nonzero())]

def get_key(board: np.ndarray, action: tuple[int, int], player: int) -> tuple[tuple, int, int]:
    # Q-table key: (flattened board, action as a single cell index, player to move).
    return tuple(board.flatten().tolist()), (action[0] * BOARD_SIZE + action[1]), player

def choose_action(Q: dict, board: np.ndarray, player: int, epsilon: float, verbose: bool = False) -> tuple[int, int]:
    """Epsilon-greedy action selection: X (1) picks the highest Q-value, O (-1) the lowest."""
    actions = possible_actions(board)
    if random.random() < epsilon:
        return random.choice(actions)
    q_values = [Q[get_key(board, a, player)] for a in actions]
    if verbose:
        print(list(zip(actions, q_values)))
    best_q = max(q_values) if player == 1 else min(q_values)
    return random.choice([a for a, q in zip(actions, q_values) if q == best_q])

def update(
    Q: dict,
    reward: float,
    board: np.ndarray,
    action: tuple[int, int],
    next_board: np.ndarray,
    next_actions: list[tuple[int, int]],
    player: int,
    alpha: float,
    gamma: float,
) -> None:
    # Minimax-style TD target: the opponent moves next, so a move by X backs up
    # the opponent's minimum next Q-value and a move by O backs up the maximum.
    best_next_fn = min if player == 1 else max
    best_next_q = best_next_fn(Q[get_key(next_board, a, player * -1)] for a in next_actions) if next_actions else 0
    td_error = reward + gamma * best_next_q - Q[get_key(board, action, player)]
    Q[get_key(board, action, player)] += alpha * td_error

def train(num_episodes: int, alpha: float, epsilon: float, gamma: float) -> dict:
    """Learn a shared Q-table by epsilon-greedy self-play over num_episodes games."""
    Q = defaultdict(float)
    for _ in tqdm(range(num_episodes), total=num_episodes, desc="Training"):
        board = np.zeros((BOARD_SIZE, BOARD_SIZE))
        player = 1
        done = False
        while not done:
            action = choose_action(Q, board, player, epsilon)
            x, y = action
            next_board = board.copy()
            next_board[x, y] = player
            next_actions = possible_actions(next_board)
            reward = get_winner(next_board)  # +1 if X just won, -1 if O just won, else 0
            update(Q, reward, board, action, next_board, next_actions, player, alpha, gamma)
            player *= -1
            done = reward != 0 or not next_actions
            board = next_board
    return Q

def visualize(board: np.ndarray) -> None:
    for row in board:
        print(" ".join("-" if val == 0 else ("X" if val == 1 else "O") for val in row))

def valid_input(board: np.ndarray) -> tuple[int, int]:
    x, y = None, None
    while (x, y) == (None, None) or board[x, y] != 0:
        try:
            parts = input("Enter coordinates (row,col): ").split(",")
            x, y = int(parts[0]), int(parts[1])
            assert 0 <= x < BOARD_SIZE and 0 <= y < BOARD_SIZE and board[x, y] == 0
        except Exception:
            x, y = None, None
    return x, y

def main(Q: dict, verbose: bool) -> None:
    """Play interactive games against the greedy (epsilon = 0) trained policy."""
    human_player = None
    player = None
    board = None

    def reset():
        nonlocal human_player, player, board
        human_player = -1 if input("Choose player to play as (X or O): ").lower() == "o" else 1
        player = 1
        board = np.zeros((BOARD_SIZE, BOARD_SIZE))
        visualize(board)

    reset()
    while True:
        if player == human_player:
            x, y = valid_input(board)
        else:
            x, y = choose_action(Q, board, player, 0, verbose=verbose)
            print(f"AI chooses {x},{y}")
        board[x, y] = player
        visualize(board)
        player *= -1
        winner = get_winner(board)
        if winner != 0:
            print("X" if winner == 1 else "O", "wins!")
            reset()
        if not possible_actions(board):
            print("Draw!")
            reset()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--num-episodes", type=int, default=50000)
    parser.add_argument("-a", "--alpha", type=float, default=0.1)
    parser.add_argument("-e", "--epsilon", type=float, default=0.9)
    parser.add_argument("-g", "--gamma", type=float, default=0.9)
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()
    Q = train(args.num_episodes, args.alpha, args.epsilon, args.gamma)
    main(Q, args.verbose)
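
To train a table and then play against it interactively, run the file directly with the flags above, for example (the filename tic_tac_toe.py is only an assumption for these examples):

    python tic_tac_toe.py --num-episodes 50000 --epsilon 0.9

A rough way to sanity-check the learned policy is to pit the greedy agent (epsilon = 0) against a uniformly random opponent, reusing the functions above. This is a minimal sketch under the same filename assumption:

    import random

    import numpy as np

    from tic_tac_toe import BOARD_SIZE, train, choose_action, possible_actions, get_winner

    Q = train(num_episodes=50000, alpha=0.1, epsilon=0.9, gamma=0.9)
    results = {1: 0, 0: 0, -1: 0}  # X (trained agent) wins / draws / O (random) wins
    for _ in range(1000):
        board = np.zeros((BOARD_SIZE, BOARD_SIZE))
        player = 1
        while True:
            if player == 1:
                x, y = choose_action(Q, board, player, 0)  # greedy trained agent plays X
            else:
                x, y = random.choice(possible_actions(board))  # random opponent plays O
            board[x, y] = player
            winner = get_winner(board)
            if winner != 0 or not possible_actions(board):
                results[winner] += 1
                break
            player *= -1
    print(results)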