@myxyy
Created March 19, 2023 11:24
import gymnasium as gym
import torch
import torch.nn as nn
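# Minimal online Q-learning on CartPole-v1: a single network updated with one-step
# TD targets (no replay buffer or separate target network).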
# Q-network: maps the 4-dimensional CartPole observation to Q-values for the 2 actions
class Q(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        return x
# Epsilon-greedy action selection: a random action with probability epsilon,
# otherwise the action with the highest predicted Q-value
def epsilon_greedy(q, epsilon):
    if torch.rand(1).item() < epsilon:
        return 0 if torch.rand(1).item() < 0.5 else 1
    else:
        return torch.argmax(q).item()
model = Q()
gamma = 0.9    # discount factor
epsilon = 0.1  # exploration rate
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

env = gym.make("CartPole-v1", render_mode="human")
observation, info = env.reset()
for _ in range(1000000):
    state = torch.from_numpy(observation).reshape(1, -1)
    q = model(state)
    q_hat = q.clone()
    action = epsilon_greedy(q, epsilon)
    observation, reward, terminated, truncated, info = env.step(action)
    state_next = torch.from_numpy(observation).reshape(1, -1)
    # One-step TD target; detach the bootstrapped value so it is treated as a constant
    if terminated or truncated:
        q_hat[:, action] = reward
    else:
        q_hat[:, action] = reward + gamma * model(state_next).max().detach()
    loss = criterion(q, q_hat)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("\rloss:{:.5f}".format(loss.item()), end="")
    if terminated or truncated:
        observation, info = env.reset()
env.close()
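As a follow-up, here is a minimal sketch of how the trained network could be run greedily (without exploration) for one evaluation episode. It assumes the script above has already run so that `model` holds the learned weights; re-creating the environment as `eval_env` is an assumption added for illustration, not part of the original gist.

# Hedged evaluation sketch: roll out the greedy policy for a single episode.
# Assumes `model` is the trained Q-network from the script above; `eval_env` is
# a freshly created environment introduced here for illustration.
eval_env = gym.make("CartPole-v1", render_mode="human")
observation, info = eval_env.reset()
total_reward = 0.0
done = False
while not done:
    state = torch.from_numpy(observation).reshape(1, -1)
    with torch.no_grad():
        action = torch.argmax(model(state)).item()  # greedy action, no epsilon
    observation, reward, terminated, truncated, info = eval_env.step(action)
    total_reward += reward
    done = terminated or truncated
print("\nepisode return:", total_reward)
eval_env.close()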