Self-play Reinforcement Learning (Q-Learning) for Tic Tac Toe in 100 lines of code
import argparse
import random
from collections import defaultdict

import numpy as np
from tqdm import tqdm

BOARD_SIZE = 3
def get_winner(board: np.ndarray) -> int:
    for player in [1, -1]:
        for i in range(3):
            if all(board[i, :] == player) or all(board[:, i] == player):
                return player
        if player == board[0, 0] == board[1, 1] == board[2, 2]:
            return player
        if player == board[0, 2] == board[1, 1] == board[2, 0]:
            return player
    return 0
def possible_actions(board: np.ndarray) -> list[tuple[int, int]]:
    return [(int(x), int(y)) for x, y in zip(*(board == 0).nonzero())]
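# A Q-table entry is keyed by (flattened board as a tuple, action index, player to move);
# the (row, col) action is packed into the single index row * BOARD_SIZE + col.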
def get_key(board: np.ndarray, action: tuple[int, int], player: int) -> tuple[tuple, int, int]:
    return tuple(board.flatten().tolist()), (action[0] * BOARD_SIZE + action[1]), player
def choose_action(Q: dict, board: np.ndarray, player: int, epsilon: float, verbose: bool = False) -> tuple[int, int]:
    actions = possible_actions(board)
    if random.random() < epsilon:
        return random.choice(actions)
    q_values = [Q[get_key(board, a, player)] for a in actions]
    if verbose:
        print(list(zip(actions, q_values)))
    best_q = max(q_values) if player == 1 else min(q_values)
    return random.choice([a for a, q in zip(actions, q_values) if q == best_q])
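# Tabular Q-learning update with a minimax-style bootstrap: after the current player's
# move, the opponent is assumed to reply with the move that is worst for the current
# player (min over next Q-values when X just moved, max when O just moved).
# Q(s, a) <- Q(s, a) + alpha * (reward + gamma * best_next_q - Q(s, a))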
def update(
    Q: dict,
    reward: float,
    board: np.ndarray,
    action: tuple[int, int],
    next_board: np.ndarray,
    next_actions: list[tuple[int, int]],
    player: int,
    alpha: float,
    gamma: float,
) -> None:
    best_next_fn = min if player == 1 else max
    best_next_q = best_next_fn(Q[get_key(next_board, a, player * -1)] for a in next_actions) if next_actions else 0
    td_error = reward + gamma * best_next_q - Q[get_key(board, action, player)]
    Q[get_key(board, action, player)] += alpha * td_error
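# Self-play training: both sides share one Q-table and follow the same epsilon-greedy
# policy. The reward is +1 when X has just won, -1 when O has just won, and 0 otherwise;
# an episode ends on a win or when the board is full.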
def train(num_episodes: int, alpha: float, epsilon: float, gamma: float) -> dict:
    Q = defaultdict(float)
    for _ in tqdm(range(num_episodes), total=num_episodes, desc="Training"):
        board = np.zeros((BOARD_SIZE, BOARD_SIZE))
        player = 1
        done = False
        while not done:
            action = choose_action(Q, board, player, epsilon)
            x, y = action
            next_board = board.copy()
            next_board[x, y] = player
            next_actions = possible_actions(next_board)
            reward = get_winner(next_board)
            update(Q, reward, board, action, next_board, next_actions, player, alpha, gamma)
            player *= -1
            done = reward != 0 or not next_actions
            board = next_board
    return Q
def visualize(board: np.ndarray) -> None:
    for row in board:
        print(" ".join("-" if val == 0 else ("X" if val == 1 else "O") for val in row))
def valid_input(board: np.ndarray) -> tuple[int, int]:
    x, y = None, None
    while (x, y) == (None, None) or board[x, y] != 0:
        try:
            parts = input("Enter coordinates (row,col): ").split(",")
            x, y = int(parts[0]), int(parts[1])
            assert 0 <= x < BOARD_SIZE and 0 <= y < BOARD_SIZE and board[x, y] == 0
        except Exception:
            x, y = None, None
    return x, y
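# Interactive play against the learned policy: the AI side always acts greedily
# (epsilon = 0), and the human enters moves as 0-indexed "row,col".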
def main(Q: dict, verbose: bool) -> None:
    human_player = None
    player = None
    board = None

    def reset():
        nonlocal human_player, player, board
        human_player = -1 if input("Choose player to play as (X or O): ").lower() == "o" else 1
        player = 1
        board = np.zeros((BOARD_SIZE, BOARD_SIZE))
        visualize(board)

    reset()
    while True:
        if player == human_player:
            x, y = valid_input(board)
        else:
            x, y = choose_action(Q, board, player, 0, verbose=verbose)
            print(f"AI chooses {x},{y}")
        board[x, y] = player
        visualize(board)
        player *= -1
        winner = get_winner(board)
        if winner != 0:
            print("X" if winner == 1 else "O", "wins!")
            reset()
        if not possible_actions(board):
            print("Draw!")
            reset()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--num-episodes", type=int, default=50000)
    parser.add_argument("-a", "--alpha", type=float, default=0.1)
    parser.add_argument("-e", "--epsilon", type=float, default=0.9)
    parser.add_argument("-g", "--gamma", type=float, default=0.9)
    parser.add_argument("-v", "--verbose", action="store_true")
    args = parser.parse_args()
    Q = train(args.num_episodes, args.alpha, args.epsilon, args.gamma)
    main(Q, args.verbose)
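A minimal usage sketch, assuming the gist is saved locally as tic_tac_toe_q_learning.py (the filename is hypothetical); the flags map to the argparse options above and are shown with their default values:

python tic_tac_toe_q_learning.py --num-episodes 50000 --alpha 0.1 --epsilon 0.9 --gamma 0.9

Training runs the requested number of self-play episodes with a tqdm progress bar, then drops into the interactive loop where you choose X or O and enter moves as 0-indexed row,col. Passing -v additionally prints the Q-values of each candidate move on the AI's turn.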