DIAMBRA debug for the wrong-opponent issue
Custom DIAMBRA Gym environment with debug logging (first file in the gist):
characters = {0: "Kyo", 1: "Benimaru", 2: "Daimon", 3: "Terry", 4: "Andy", 5: "Joe", | |
6: "Ryo", 7: "Robert", 8: "Yuri", 9: "Leona", 10: "Ralf", 11: "Clark", 12: "Athena", | |
13: "Kensou", 14: "Chin", 15: "Chizuru", 16: "Mai", 17: "King", 18: "Kim", 19: "Chang", 20: "Choi", | |
21: "Yashiro", 22: "Shermie", 23: "Chris", 24: "Yamazaki", 25: "Mary", | |
26: "Billy", 27: "Iori", 28: "Mature", 29: "Vice", 30: "Heidern", | |
31: "Takuma", 32: "Saisyu", 33: "Heavy-D!", 34: "Lucky", 35: "Brian", | |
36: "Eiji", 37: "Kasumi", 38: "Shingo", 39: "Rugal", 40: "Geese", | |
41: "Krauser", 42: "Mr.Big", 43: "Goenitz", 44: "Orochi"} | |
# DIAMBRA Gym base class providing frame and additional info as observations
class DiambraGym1P(DiambraGymHardcore1P):
    def __init__(self, env_settings):
        super().__init__(env_settings)

        # Dictionary observation space
        observation_space_dict = {}
        observation_space_dict['frame'] = spaces.Box(low=0, high=255,
                                                     shape=(self.hwc_dim[0],
                                                            self.hwc_dim[1],
                                                            self.hwc_dim[2]),
                                                     dtype=np.uint8)

        player_spec_dict = {}

        # Adding env additional observations (side-specific)
        for k, v in self.ram_states.items():
            if k == "stage":
                continue

            if k[-2:] == "P1":
                knew = "own" + k[:-2]
            else:
                knew = "opp" + k[:-2]

            # Discrete spaces (binary / categorical)
            if v[0] == 0 or v[0] == 2:
                player_spec_dict[knew] = spaces.Discrete(v[2] + 1)
            elif v[0] == 1:  # Box spaces
                player_spec_dict[knew] = spaces.Box(low=v[1], high=v[2],
                                                    shape=(1,), dtype=np.int32)
            else:
                raise RuntimeError(
                    "Only Discrete (Binary/Categorical) | Box Spaces allowed")

        actions_dict = {
            "move": spaces.Discrete(self.n_actions[0]),
            "attack": spaces.Discrete(self.n_actions[1])
        }
        player_spec_dict["actions"] = spaces.Dict(actions_dict)

        observation_space_dict["P1"] = spaces.Dict(player_spec_dict)
        observation_space_dict["stage"] = spaces.Box(low=self.ram_states["stage"][1],
                                                     high=self.ram_states["stage"][2],
                                                     shape=(1,), dtype=np.int8)

        self.observation_space = spaces.Dict(observation_space_dict)
    def ram_states_integration(self, frame, data):
        observation = {}
        observation["frame"] = frame
        observation["stage"] = np.array([data["stage"]], dtype=np.int8)

        player_spec_dict = {}

        # Adding env additional observations (side-specific)
        for k, v in self.ram_states.items():
            if k == "stage":
                continue

            if k[-2:] == self.player_side:
                knew = "own" + k[:-2]
            else:
                knew = "opp" + k[:-2]

            # Box spaces
            if v[0] == 1:
                player_spec_dict[knew] = np.array([data[k]], dtype=np.int32)
            else:  # Discrete spaces (binary / categorical)
                player_spec_dict[knew] = data[k]

        actions_dict = {
            "move": data["moveAction{}".format(self.player_side)],
            "attack": data["attackAction{}".format(self.player_side)],
        }
        player_spec_dict["actions"] = actions_dict
        observation["P1"] = player_spec_dict

        return observation
    def step(self, action):
        self.frame, reward, done, data = self.step_complete(action)
        observation = self.ram_states_integration(self.frame, data)

        env_rank = self.env_settings.rank
        stage = observation['stage'][0]

        # Stage reward: when a stage is completed, replace the step reward with a
        # bonus proportional to the stage that was just cleared
        if data["stage_done"]:
            stage = stage - 1
            stage_reward = stage * 100
            reward = stage_reward
            print(f"({env_rank}) *** extra stage {stage} reward {stage_reward} ***")

        # Debug logging: stage, win counts, health values and the character names
        # decoded via the `characters` table above
        if reward != 0.0:
            own_wins = observation['P1']['ownWins'][0]
            opp_wins = observation['P1']['oppWins'][0]
            own_health = observation['P1']['ownHealth'][0]
            opp_health = observation['P1']['oppHealth'][0]
            own_character = characters[observation['P1']['ownChar']]
            opp_character = characters[observation['P1']['oppChar']]
            print(
                # f"env[{env_rank}] stage[{stage}] wins[{own_wins}:{opp_wins}] {own_character}:{opp_character} {reward}[{own_health}:{opp_health}]")
                f"({env_rank}) {stage}-{own_wins}-{opp_wins} [{own_health}:{reward}:{opp_health}] {own_character}[{observation['P1']['ownChar']}]:{opp_character}[{observation['P1']['oppChar']}]")

        return observation, reward, done,\
            {"round_done": data["round_done"], "stage_done": data["stage_done"],
             "game_done": data["game_done"], "ep_done": data["ep_done"], "env_done": data["env_done"]}
    # Reset the environment
    def reset(self):
        self.frame, data, self.player_side = self.arena_engine.reset()
        observation = self.ram_states_integration(self.frame, data)
        return observation
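A quick sanity check for the wrong-opponent issue (not part of the gist) is to decode the character indices reported in the observation with the `characters` table above and compare them against the fighters actually shown on screen. A minimal sketch, assuming `env` is an already-created instance of the DiambraGym1P class defined above:

# Hypothetical check: `env` is assumed to be an instance of DiambraGym1P
obs = env.reset()
print("own:", characters[obs["P1"]["ownChar"]],
      "opp:", characters[obs["P1"]["oppChar"]])  # should match the on-screen match-up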
Stable-Baselines3 PPO training script (second file in the gist):
import diambra.arena
from diambra.arena.stable_baselines3.sb3_utils import linear_schedule, AutoSave
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env
from stable_baselines3.common.callbacks import BaseCallback


# Callback that renders the environment on every training step
class RenderCallback(BaseCallback):
    def __init__(self, env):
        super(RenderCallback, self).__init__()
        self.env = env

    def _on_step(self) -> bool:
        self.env.render()
        return True
if __name__ == "__main__": | |
# Settings | |
settings = {} | |
settings["hardcore"] = False | |
settings["player"] = "P1" | |
# settings["step_ratio"] = 1 | |
settings["frame_shape"] = (128, 128, 1) | |
settings["show_final"] = False | |
settings["action_space"] = "multi_discrete" | |
# settings["action_space"] = "discrete" | |
settings["attack_but_combination"] = False | |
settings["difficulty"] = 2 | |
# settings["characters"] = ("Kyo", "Benimaru", "Daimon") | |
settings["characters"] = ("Kyo", "Benimaru", "Chang") | |
settings["fighting_style"] = 1 | |
# Wrappers Settings | |
wrappers_settings = {} | |
wrappers_settings["reward_normalization"] = True | |
wrappers_settings["actions_stack"] = 8 | |
wrappers_settings["frame_stack"] = 5 | |
wrappers_settings["dilation"] = 1 | |
wrappers_settings["scale"] = True | |
wrappers_settings["exclude_image_scaling"] = True | |
wrappers_settings["flatten"] = True | |
wrappers_settings["filter_keys"] = ["stage", "P1_ownHealth", "P1_oppHealth", | |
"P1_ownChar", "P1_oppChar", | |
"P1_actions_move", "P1_actions_attack", | |
"P1_ownBarType", "P1_oppBarType", | |
"P1_ownPowerBar", "P1_oppPowerBar", | |
"P1_ownSpecialAttacks", "P1_oppSpecialAttacks", | |
] | |
    # Create environment
    env, num_envs = make_sb3_env(
        "kof98umh", settings, wrappers_settings, seed=98)

    # PPO hyperparameters
    learning_rate = 1e-04  # 1e-04
    clip_range = 0.3  # 0.3
    n_epochs = 3
    n_steps = 128
    gamma = 0.99
    gae_lambda = 0.9
    batch_size = 128
    ent_coef = 0.02  # 0.01
    policy_kwargs = dict(net_arch=[dict(vf=[256, 128, 128], pi=[128, 64, 64])])

    PATH = "kof2/"
    model_path = PATH + "saved_model"

    print("Training a new model")
    agent = PPO('MultiInputPolicy', env, verbose=2, seed=98, tensorboard_log=PATH + "logs/",
                n_epochs=n_epochs, policy_kwargs=policy_kwargs,
                n_steps=n_steps, batch_size=batch_size, gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef,
                clip_range=linear_schedule(clip_range, clip_range / 2),
                learning_rate=linear_schedule(learning_rate, learning_rate / 10))
print("Policy architecture:") | |
print(agent.policy) | |
# Train the agent | |
print("Begin training the agent") | |
agent.learn(total_timesteps=4000000, progress_bar=False, callback=RenderCallback(env)) | |
print("Train the agent finished") | |
agent.save(model_path) | |
mean_reward, std_reward = evaluate_policy(agent, agent.get_env(), n_eval_episodes=10) | |
print("Reward: {} (avg) ± {} (std)".format(mean_reward, std_reward)) | |
env.close() |
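For reference (not part of the gist), the checkpoint saved by agent.save(model_path) can later be reloaded for evaluation with the standard Stable-Baselines3 API. A minimal sketch, assuming the environment is rebuilt with the same settings and wrappers_settings used for training:

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env

# Rebuild the environment exactly as in the training script
# (settings and wrappers_settings are assumed to be defined as above)
env, num_envs = make_sb3_env("kof98umh", settings, wrappers_settings, seed=98)

# Reload the saved checkpoint and evaluate it
agent = PPO.load("kof2/saved_model", env=env)
mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=10)
print("Reward: {} (avg) ± {} (std)".format(mean_reward, std_reward))
env.close()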