Skip to content

Instantly share code, notes, and snippets.

View YannBerthelot's full-sized avatar

YannBerthelot

View GitHub Profile
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple
from flax.training.train_state import TrainState
import distrax
from gymnax.wrappers.purerl import LogWrapper
@YannBerthelot
YannBerthelot / taxi_solver.py
Created September 20, 2023 07:52
Compute the state values of the taxi-driver MDP using a linear equation solver
import numpy as np
# Coefficient rows of the linear system solved for the taxi-driver MDP state
# values. In each row the three fractions sum to 1 and are scaled by 0.9
# (presumably the discount factor — confirm against the rest of the script);
# the "- 1" on the row's own entry moves that state's value to the left-hand
# side of the equation.
_coefficients_equation_A = np.array(
    (0.9 * (13 / 48) - 1, 0.9 * (3 / 8), 0.9 * (17 / 48))
)
_coefficients_equation_B = np.array(
    (0.9 * (9 / 32), 0.9 * (7 / 16) - 1, 0.9 * (9 / 32))
)
_coefficients_equation_C = np.array(
"""
'Realistic' config to be used for testing
"""
import numpy as np
from easygrid.config.realistic_config import grid_config, mg_config
from easygrid.data.data_utils import DATA_FOLDER, get_indexes
from easygrid.types import BatteryConfig, LoadConfig, PvConfig
# Computed once, at import time, from the project's data folder. get_indexes
# is a project helper; its return type is not visible here.
INDEXES = get_indexes(
    DATA_FOLDER,
)
"""
'Realistic' config to be used for testing
"""
from functools import partial
import numpy as np
from easygrid.math_utils import get_hourly_variation
from easygrid.types import (
BatteryConfig,
import os
from typing import Tuple, Union
import numpy as np
import numpy.typing as npt
import torch
import pathlib
import pickle
def t(x):
    """Convert a NumPy array into a float32 torch tensor.

    ``torch.from_numpy`` shares memory with ``x``. ``.float()`` then copies
    only when the dtype is not already float32, so a float32 input yields a
    tensor that still shares memory with ``x``.
    """
    tensor = torch.from_numpy(x)
    return tensor.float()
def compute_mach(V, speed_of_sound=343.0):
    """Return the Mach number and speed for a velocity vector.

    Only the first two components of ``V`` are used; any further components
    are ignored (axis convention is not visible here — confirm with callers).

    Args:
        V: Indexable velocity with at least two components.
        speed_of_sound: Reference speed for the Mach number, in the same units
            as ``V``. Defaults to 343.0, the original hard-coded value (the
            commonly quoted speed of sound in air in m/s at about 20 °C —
            presumably the intended unit).

    Returns:
        Tuple ``(mach, norm_v)``. ``norm_v`` is the Euclidean norm of
        ``(V[0], V[1])`` and ``mach = norm_v / speed_of_sound``.
    """
    norm_v = np.linalg.norm([V[0], V[1]])
    return norm_v / speed_of_sound, norm_v
def newton(theta, gamma, thrust, lift, drag, P, m):
    """Return the 2-D acceleration produced by lift, drag, thrust and ``P``.

    Projects the three force magnitudes onto the X and Z axes, subtracts
    ``P`` on the Z axis, then applies a = F / m.

    Args:
        theta: Angle in radians used to project lift and thrust (presumably
            the pitch attitude — confirm with callers).
        gamma: Angle in radians used to project drag (presumably the
            flight-path angle — confirm with callers).
        thrust: Thrust magnitude.
        lift: Lift magnitude.
        drag: Drag magnitude.
        P: Force subtracted on the Z axis (presumably the weight — confirm).
        m: Mass the net forces are divided by.

    Returns:
        ``np.array([a_x, a_z])``.
    """
    # np.cos/np.sin are used explicitly: the visible file never imports bare
    # cos/sin, so the original relied on a name not in scope here.
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    sin_gamma = np.sin(gamma)
    cos_gamma = np.cos(gamma)
    lift_drag_thrust = np.array([lift, drag, thrust])
    # Z axis: lift*cos(theta) - drag*sin(gamma) + thrust*sin(theta) - P
    F_z = np.sum(lift_drag_thrust * np.array([cos_theta, -sin_gamma, sin_theta])) - P
    # X axis: -lift*sin(theta) - drag*|cos(gamma)| + thrust*cos(theta).
    # abs() keeps the sign of the drag term from flipping with gamma's quadrant.
    F_x = np.sum(lift_drag_thrust * np.array([-sin_theta, -abs(cos_gamma), cos_theta]))
    # Acceleration from a = F / m
    return np.array([F_x / m, F_z / m])
import numpy as np
from numba import njit
# Module-level constants. Their use is not visible in this chunk; the names
# suggest a critical Mach number and a minimum drag coefficient consumed by
# compute_cx — confirm.
MACH_CRITIC, C_X_MIN = 0.78, 0.095
@njit(nogil=True, fastmath=True)
def compute_cx(alpha, mach):
"""
import os
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_checker import check_env
from gym_environment import PlaneEnv
if __name__ == "__main__":
wrappable_env = PlaneEnv(task="level-flight")
# check if the env satisfies gym requirements
import os
import time
from configparser import ConfigParser
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from gym_environment import PlaneEnv
# Resolve config.ini relative to this script's own directory, so the path
# does not depend on the current working directory.
thisfolder = os.path.dirname(
    os.path.abspath(__file__)
)
config_path = os.path.join(thisfolder, "config.ini")

# Parser for the script's settings; nothing has been read into it yet.
parser = ConfigParser()