Skip to content

Instantly share code, notes, and snippets.

@maxpagels
Last active June 12, 2018 13:42
Show Gist options
  • Save maxpagels/a8ff4ad999613c52e4fe378a2f73152c to your computer and use it in GitHub Desktop.
Save maxpagels/a8ff4ad999613c52e4fe378a2f73152c to your computer and use it in GitHub Desktop.
# Python implementation of EXP3 (the Exponential-weight algorithm for Exploration
# and Exploitation) for solving adversarial multi-armed bandit problems. Based on
# the original paper, "The Nonstochastic Multiarmed Bandit Problem" by Auer,
# Cesa-Bianchi, Freund and Schapire: http://rob.schapire.net/papers/AuerCeFrSc01.pdf
import numpy as np
import time
# Seed NumPy's global RNG so repeated runs produce the same picks and rewards.
np.random.seed(12345)

# Per-arm probability of paying out a reward of 1.
# Not really known to the learner; this is for simulation only.
reward_probs = np.array([0.0202, 0.02, 0.015, 0.001])
# Derived from the table above so the arm count cannot drift out of sync with
# the number of probabilities (a mismatch would skew the sampling floor or
# index past the end of the table).
n_arms = len(reward_probs)
timesteps = 100000  # number of rounds to play
gamma = 0.1  # hyperparameter: weight of the uniform part of the sampling mix
w = np.ones(n_arms)  # one weight per arm, all equal at the start
a = np.arange(n_arms)  # arm indices handed to np.random.choice
total_picks = np.zeros(n_arms)  # bookkeeping: times each arm was played
total_rewards = np.zeros(n_arms)  # bookkeeping: reward collected per arm
def get_reward(idx, probs=None):
    """Simulate one Bernoulli pull of arm ``idx``.

    Args:
        idx: Index of the arm to pull.
        probs: Optional sequence of per-arm success probabilities. When
            omitted, the module-level ``reward_probs`` table is used.

    Returns:
        1 with probability ``probs[idx]``, otherwise 0.
    """
    if probs is None:
        probs = reward_probs
    # np.random.rand() draws from the half-open interval [0, 1), so the
    # comparison must be strict: with "<=" an arm whose probability is exactly
    # 0 could still pay out whenever the draw happens to be 0.0.
    return int(np.random.rand() < probs[idx])
# Once the largest weight passes this bound, all weights are rescaled. The
# sampling distribution only depends on the ratios w / sum(w), so rescaling
# does not change the algorithm; it only stops the multiplicative updates
# from overflowing to inf (which would turn p into NaN and make
# np.random.choice raise) on long horizons or high reward rates.
weight_cap = 1e100

then = time.perf_counter()  # monotonic clock, suited to measuring elapsed time
for _ in range(timesteps):
    # Mix the normalised weights with a uniform distribution over the arms.
    # The gamma / n_arms term keeps every p[i] above zero (for gamma > 0),
    # so the division by p[idx] below is well defined.
    p = (1 - gamma) * (w / np.sum(w)) + (gamma / n_arms)
    idx = np.random.choice(a, p=p)
    total_picks[idx] += 1
    reward = get_reward(idx)
    total_rewards[idx] += reward
    # Importance-weighted reward estimate for the arm that was played. Every
    # other arm's estimate is 0, i.e. a multiplicative update of exp(0) == 1,
    # so only the played arm's weight has to be touched.
    w[idx] *= np.exp((gamma * (reward / p[idx])) / n_arms)
    # Only w[idx] can have grown this round, so it is the only candidate for
    # crossing the cap.
    if w[idx] > weight_cap:
        w /= w.max()
now = time.perf_counter()

elapsed = now - then
# Guard the throughput figure against a zero-length interval
# (e.g. when timesteps is 0).
rate = timesteps / elapsed if elapsed > 0 else float("inf")
print(("Total running time (in seconds) for {0}"
       " timesteps ({0} choices and {0} rewards): {1}"
       " ({2} per second)").format(timesteps, elapsed, rate))
print("Total picks per arm: ", total_picks)
print("Total rewards per arm: ", total_rewards)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment