# explore-first-knapsack.py
# Gist: maxpagels/2815ad2faa9058fb6f3e3ed7b3d487ff (last active November 5, 2021 13:50)
# Note: GitHub flagged this file as containing hidden or bidirectional Unicode
# text that may be interpreted or compiled differently than it appears; review
# it in an editor that reveals hidden Unicode characters.
import numpy as np

# -- Globals
# Problem definition for a bandit whose arms consume limited resources.
N_ARMS = 4
N_RESOURCES = 3
N_EXPLORATION_ROUNDS = 1  # NOTE(review): not referenced anywhere in the visible script

# CONSUMPTION_PER_PLAY[i][j] is the amount of resource j used by a single play
# of arm i (one row per arm, one column per resource).
CONSUMPTION_PER_PLAY = [
    [4, 0, 0],  # Example: first arm consumes 4 of resource 1 and 0 of resources 2 & 3
    [1, 5, 1],
    [6, 1, 4],
    [9, 1, 0],
]

# BUDGETS[j] is the total amount of resource j that may be consumed.
BUDGETS = [40, 71, 82]
# -- Exploration-first
# Assume you have explored for N rounds and recorded
# empirical mean reward probabilities per arm.
# REWARD_PROBS[i] is the estimated probability that arm i pays a reward of 1.
REWARD_PROBS = [
    0.4,  # 1 reward with probability 0.4, 0 with probability 0.6
    0.2,
    0.1,
    0.5,
]
# -- Exploitation phase
# Idea: solve a MIP to maximise reward subject to budget constraints,
# thereby getting an empirical pmf of action probabilities
from ortools.linear_solver import pywraplp

solver = pywraplp.Solver.CreateSolver('SCIP')
if solver is None:
    # CreateSolver returns None when the requested backend is not available;
    # fail here with a clear message instead of an AttributeError further down.
    raise RuntimeError("The SCIP backend is not available in this OR-Tools installation.")

# Variables: one non-negative integer variable per arm, denoting the number
# of plays of that arm.
actions = [
    solver.IntVar(0, solver.infinity(), f"action_{i}") for i in range(N_ARMS)
]

# Constraints: for every resource j, the total consumption over all arms
# (plays of arm i times its per-play consumption of j) must not exceed the
# budget of that resource.
for j, b in enumerate(BUDGETS):
    consumption = solver.Sum(
        [a * CONSUMPTION_PER_PLAY[i][j] for i, a in enumerate(actions)]
    )
    solver.Add(consumption <= b)

# Objective: maximise expected reward (plays weighted by each arm's empirical
# reward probability) subject to the constraints above.
solver.Maximize(
    solver.Sum([a * REWARD_PROBS[i] for i, a in enumerate(actions)])
)
# Solve the MIP and report the number of plays chosen for each arm.
status = solver.Solve()
if status == pywraplp.Solver.OPTIMAL:
    print('MIP solution:')
    print('Simulated reward =', solver.Objective().Value())
    plays = [a.solution_value() for a in actions]
    print(plays)
else:
    # Bind `plays` on this path too, so the sampling section below can detect
    # the missing solution instead of raising NameError.
    plays = None
    print('The problem does not have an optimal solution.')

# -- Print example predictions
# Normalise the per-arm play counts into a pmf and sample arms from it.
if plays is not None:
    dist = np.array(plays)
    total = dist.sum()
    if total > 0:
        pmf = dist / total
        arms = np.arange(N_ARMS)  # loop-invariant, so built once
        print("Example predictions, 20 rounds:")
        for i in range(0, 20):
            print(f"round {i}, play arm {np.random.choice(arms, p=pmf)}")
    else:
        # All-zero plays would make the pmf NaN and np.random.choice would raise.
        print('The optimal solution plays no arm; there is nothing to sample.')
# (GitHub page footer: sign up or sign in on GitHub to comment on this gist.)