Monte Carlo Cart Pole Balancing with Epsilon-Greedy Policy Improvement
import gymnasium as gym
import numpy as np


class MonteCarloCartPole:
    def __init__(self, n_bins=10, learning_rate=0.1, gamma=0.99, epsilon=0.1):
        self.env = gym.make('CartPole-v1')
        self.n_bins = n_bins
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon

        # Define state space bins for discretization
        self.cart_pos_bins = np.linspace(-4.8, 4.8, n_bins)
        self.cart_vel_bins = np.linspace(-5, 5, n_bins)
        self.pole_ang_bins = np.linspace(-0.418, 0.418, n_bins)
        self.pole_vel_bins = np.linspace(-5, 5, n_bins)

        # Initialize Q-table: one entry per discretized state and action
        self.q_table = np.zeros((n_bins, n_bins, n_bins, n_bins, 2))

    def discretize_state(self, state):
        """Convert continuous state to discrete state indices."""
        cart_pos_idx = np.digitize(state[0], self.cart_pos_bins) - 1
        cart_vel_idx = np.digitize(state[1], self.cart_vel_bins) - 1
        pole_ang_idx = np.digitize(state[2], self.pole_ang_bins) - 1
        pole_vel_idx = np.digitize(state[3], self.pole_vel_bins) - 1

        # Clip indices to handle values that fall outside the bin ranges
        cart_pos_idx = np.clip(cart_pos_idx, 0, self.n_bins - 1)
        cart_vel_idx = np.clip(cart_vel_idx, 0, self.n_bins - 1)
        pole_ang_idx = np.clip(pole_ang_idx, 0, self.n_bins - 1)
        pole_vel_idx = np.clip(pole_vel_idx, 0, self.n_bins - 1)

        return (cart_pos_idx, cart_vel_idx, pole_ang_idx, pole_vel_idx)

    def get_action(self, state, evaluate=False):
        """Epsilon-greedy policy."""
        # Explore with probability epsilon (training only)
        if not evaluate and np.random.random() < self.epsilon:
            return np.random.choice([0, 1])
        # Otherwise exploit: pick the greedy action from the Q-table
        state_disc = self.discretize_state(state)
        return np.argmax(self.q_table[state_disc])

    def generate_episode(self):
        """Generate one episode using the current policy."""
        episode = []
        state, _ = self.env.reset()
        done = False
        truncated = False

        while not (done or truncated):
            state_disc = self.discretize_state(state)
            action = self.get_action(state)
            next_state, reward, done, truncated, _ = self.env.step(action)
            episode.append((state_disc, action, reward))
            state = next_state

        return episode

    def train(self, n_episodes=10000, epsilon_decay_rate=0.9):
        """Train using Monte Carlo policy iteration."""
        episode_rewards = []

        for episode in range(n_episodes):
            # Generate an episode with the current epsilon-greedy policy
            current_episode = self.generate_episode()
            episode_rewards.append(sum(r for _, _, r in current_episode))

            # Walk the episode backwards, accumulating the discounted return G
            G = 0
            for t in range(len(current_episode) - 1, -1, -1):
                state_disc, action, reward = current_episode[t]
                G = self.gamma * G + reward

                # Every-visit MC update: move the Q-value a step of size
                # learning_rate toward the observed return
                old_value = self.q_table[state_disc][action]
                self.q_table[state_disc][action] += self.learning_rate * (G - old_value)

            # Decay epsilon (floored at 0.01 to keep a little exploration)
            self.epsilon = max(0.01, self.epsilon * epsilon_decay_rate)

            # Print progress
            if (episode + 1) % 100 == 0:
                avg_reward = np.mean(episode_rewards[-100:])
                print(f"Episode {episode + 1}, Average Reward (last 100): {avg_reward:.2f}")

        return episode_rewards

    def evaluate(self, n_episodes=100):
        """Evaluate the learned (greedy) policy."""
        total_rewards = []

        for _ in range(n_episodes):
            state, _ = self.env.reset()
            episode_reward = 0
            done = False
            truncated = False

            while not (done or truncated):
                action = self.get_action(state, evaluate=True)
                state, reward, done, truncated, _ = self.env.step(action)
                episode_reward += reward

            total_rewards.append(episode_reward)

        return np.mean(total_rewards)


# Example usage
def main():
    agent = MonteCarloCartPole(n_bins=15, learning_rate=0.1, gamma=0.999, epsilon=1)

    # Train the agent
    print("Training the agent...")
    episode_rewards = agent.train(n_episodes=50000, epsilon_decay_rate=0.9)

    # Evaluate the learned policy
    print("\nEvaluating the learned policy...")
    avg_reward = agent.evaluate()
    print(f"Average reward over 100 episodes: {avg_reward:.2f}")

    # Visualize 10 episodes
    print("\nVisualizing 10 episodes...")
    env_render = gym.make('CartPole-v1', render_mode='human')
    agent.env = env_render  # Update agent's environment to rendering version

    for episode in range(10):
        state, _ = agent.env.reset()
        episode_reward = 0
        done = False
        truncated = False

        while not (done or truncated):
            agent.env.render()  # Render the environment
            action = agent.get_action(state, evaluate=True)
            state, reward, done, truncated, _ = agent.env.step(action)
            episode_reward += reward

        print(f"Episode {episode + 1} reward: {episode_reward}")

    # Close the environment
    agent.env.close()


if __name__ == "__main__":
    main()
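

# --- Note on the update rule (a sketch, not part of the script above) ---
# train() uses an every-visit, constant-step-size Monte Carlo update: every
# occurrence of a (state, action) pair in an episode nudges its Q-value toward
# the return observed from that point onward. A first-visit variant, sketched
# below, updates each pair only once per episode, at its first occurrence.
# It assumes the same q_table layout and the same episode format (a list of
# (state_disc, action, reward) tuples) as the class above.
def first_visit_mc_update(q_table, episode, learning_rate=0.1, gamma=0.99):
    """Apply a first-visit, constant-step-size MC update to q_table in place."""
    G = 0.0
    first_visit_returns = {}
    for state_disc, action, reward in reversed(episode):
        G = gamma * G + reward
        # Iterating backwards and overwriting means the stored value ends up
        # being the return from the earliest (first) visit of each pair.
        first_visit_returns[(state_disc, action)] = G
    for (state_disc, action), first_return in first_visit_returns.items():
        q_table[state_disc][action] += learning_rate * (first_return - q_table[state_disc][action])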