A quick example of computing the reachable belief set and running infinite-horizon value iteration over it using pomdp-py.
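The idea, as a sketch in standard belief-MDP notation: enumerate the beliefs reachable from the initial belief by applying the belief update for every action-observation pair, then run value iteration over that finite set with the backup

V(b) = \max_{a} \Big[ \mathbb{E}_{b}[R(s,a)] + \gamma \sum_{z} \Pr(z \mid b, a)\, V(b^{a,z}) \Big],

where b^{a,z} is the belief obtained from b after taking action a and observing z. The code below computes the two probability terms with pomdp-py's expected_reward and belief_observation_model helpers.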
import random
import pprint

import pomdp_py
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from pomdp_py.algorithms.value_function import (
    expected_reward,
    belief_observation_model,
)
from pomdp_py.problems.tiger.tiger_problem import TigerProblem, TigerState

PRECISION = 4  # precision of belief probabilities

def _to_tuple(b: dict) -> tuple:
    """Given a belief dictionary, return a hashable tuple representation
    (states paired with rounded probabilities, sorted by state)."""
    return tuple(
        sorted(((s, round(b[s], PRECISION)) for s in b), key=lambda elm: str(elm[0]))
    )


def _to_dict(b_tuple: tuple) -> dict:
    """Given a tuple belief, return a dictionary from state to probability."""
    return {elm[0]: elm[1] for elm in b_tuple}


def reachable_belief(b, A, Z, T, O) -> tuple:
    """Given an initial belief b, compute the set of beliefs reachable from b
    under the dynamics defined by A, Z, T, O. Returns (reachable_set, transitions),
    where transitions maps (b, a, z) to the next belief b'."""
    reachable_set = {_to_tuple(b)}
    transitions = {}  # maps (b, a, z) to b'
    _reachable_belief(b, A, Z, T, O, reachable_set, transitions)
    return reachable_set, transitions


def _reachable_belief(b, A, Z, T, O, reachable_set, transitions) -> None:
    """Recursive helper: given a belief b of a POMDP <S, A, Z, T, O, R>, populate
    `reachable_set` with the beliefs reachable from b and record the belief-level
    transitions. Only A, Z, T, O are needed."""
    for a in A:
        for z in Z:
            b_next = pomdp_py.belief_update(b, a, z, T, O)
            b_next_tuple = _to_tuple(b_next)
            transitions[(_to_tuple(b), a, z)] = b_next_tuple
            if b_next_tuple not in reachable_set:
                reachable_set.add(b_next_tuple)
                _reachable_belief(b_next, A, Z, T, O, reachable_set, transitions)
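
# Note: _reachable_belief enumerates the belief MDP by depth-first search. Rounding
# belief probabilities to PRECISION decimal places in _to_tuple is what keeps the
# reachable set finite here: with exact arithmetic, repeatedly listening in the
# Tiger problem keeps producing new beliefs, so the recursion only terminates
# because rounded beliefs eventually repeat. The result is therefore a discretized
# approximation of the exact reachable belief set.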

def value_iteration_infinite_horizon(
    Rb0, A, Z, T, O, R, gamma, belief_transitions, max_iter=1000, epsilon=1e-4
):
    """Perform infinite-horizon value iteration over the reachable belief states Rb0;
    also returns the greedy policy."""
    V = {b: random.uniform(-5, 5) for b in Rb0}
    pi = {}
    for step in range(max_iter):
        Vp = {}
        for b in Rb0:
            Vp[b], pi[b] = _value(
                _to_dict(b), A, Z, T, O, R, gamma, V, belief_transitions
            )
        diff = _value_difference(V, Vp)
        if diff < epsilon:
            print(f"Value iteration converged after {step + 1} iterations.")
            return Vp, pi
        V = Vp
    return V, pi
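
# Stopping rule: value iteration is a gamma-contraction in the sup norm, so once the
# change between successive iterates falls below epsilon, the returned values are
# within roughly gamma * epsilon / (1 - gamma) of the true fixed point (the standard
# value-iteration error bound).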

def _value(b, A, Z, T, O, R, gamma, V, belief_transitions):
    """Compute the value at belief b, making use of future values from V.
    Returns (best Q-value, best action)."""
    max_qval = float("-inf")
    best_action = None
    for a in A:
        qval = _qvalue(b, a, Z, T, O, R, gamma, V, belief_transitions)
        if qval > max_qval:
            max_qval = qval
            best_action = a
    return max_qval, best_action

def _qvalue(b, a, Z, T, O, R, gamma, V, belief_transitions):
    """Compute the Q-value at (b, a), making use of future values from V."""
    r = expected_reward(b, R, a, T)
    expected_future_value = 0.0
    for z in Z:
        # accumulate Pr(z | b, a) * V(b')
        prob_z = belief_observation_model(z, b, a, T, O)
        # only observations with non-zero probability contribute
        if prob_z > 0:
            next_b = belief_transitions[(_to_tuple(b), a, z)]
            next_value = V[next_b]
            expected_future_value += prob_z * next_value
    return r + gamma * expected_future_value
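
# In symbols, _qvalue computes
#     Q(b, a) = E_b[R(s, a)] + gamma * sum_z Pr(z | b, a) * V(b'),
# where b' is the belief reached from b after action a and observation z,
# looked up in belief_transitions.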

def _value_difference(V1, V2):
    """Return the maximum absolute (sup-norm) difference between two value functions."""
    diffs = [abs(V1[b] - V2[b]) for b in V1]
    return max(diffs)

def _create_pomdp(noise=0.15, init_state="tiger-left"):
    """Instantiate a Tiger POMDP with a uniform initial belief and return
    (b0, s0, S, A, Z, T, O, R, gamma)."""
    tiger = TigerProblem(
        noise,
        TigerState(init_state),
        pomdp_py.Histogram(
            {TigerState("tiger-left"): 0.5, TigerState("tiger-right"): 0.5}
        ),
    )
    T = tiger.agent.transition_model
    O = tiger.agent.observation_model
    S = list(T.get_all_states())
    Z = list(O.get_all_observations())
    A = list(tiger.agent.policy_model.get_all_actions())
    R = tiger.agent.reward_model
    gamma = 0.95
    b0 = tiger.agent.belief
    s0 = tiger.env.state
    return b0, s0, S, A, Z, T, O, R, gamma

def test():
    b0, s0, S, A, Z, T, O, R, gamma = _create_pomdp()
    Rb0, btrans = reachable_belief(b0, A, Z, T, O)
    print(f"Reachable belief from {b0} has {len(Rb0)} belief states:")
    pprint.pp(Rb0)

    value, policy = value_iteration_infinite_horizon(Rb0, A, Z, T, O, R, gamma, btrans)
    print("Value function:")
    pprint.pp(value)
    print("Policy:")
    pprint.pp(policy)

    # Let's make a plot: value of each reachable belief against P(tiger-left),
    # colored by the greedy action at that belief.
    data = {"b_tiger_left": [], "value": [], "action": []}
    for b in value:
        data["b_tiger_left"].append(_to_dict(b)[TigerState("tiger-left")])
        data["value"].append(value[b])
        data["action"].append(policy[b])
    sns.scatterplot(pd.DataFrame(data), x="b_tiger_left", y="value", hue="action")
    plt.show()


if __name__ == "__main__":
    test()
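Assuming the pomdp_py Tiger model uses the classic parameters, the plot should show the familiar threshold structure: each point is one reachable belief, "listen" is the greedy action near the uniform belief, and the open-door actions take over once the belief in one side is high enough.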
Plot: [scatter plot of value vs. b_tiger_left for each reachable belief, colored by action]