Created
February 18, 2025 16:21
-
-
Save JacobFV/a636687dac14ed869d0782680c9637b4 to your computer and use it in GitHub Desktop.
composite gym spaces
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
from typing import ( | |
Type, | |
Any, | |
List, | |
Dict, | |
Literal, | |
Optional, | |
Set, | |
Union, | |
get_type_hints, | |
get_origin, | |
get_args, | |
) | |
from pydantic import ValidationError, BaseModel, Field | |
from random import choice, randint, uniform | |
import string | |
from enum import Enum | |
import types # for types.UnionType | |
class Space(BaseModel, ABC):
    """Common interface shared by every observation/action space.

    Concrete spaces are pydantic models that implement two operations:
    membership testing (``contains``) and random generation (``sample``).
    """

    @abstractmethod
    def contains(self, x: Any) -> bool:
        """Tell whether ``x`` is a valid member of this space."""
        ...

    @abstractmethod
    def sample(self) -> Any:
        """Draw a random value that belongs to this space."""
        ...
class DiscreteSpace(Space):
    """Space made up of an explicit, finite list of admissible values.

    >>> space = DiscreteSpace(values=[1, 2, 3])
    >>> space.contains(2)
    True
    >>> space.contains(4)
    False
    >>> sample = space.sample()
    >>> sample in [1, 2, 3]
    True
    """

    # Every value that is considered a member of the space.
    values: List[Any]

    def contains(self, x: Any) -> bool:
        """Return whether ``x`` is one of the listed values."""
        is_member = x in self.values
        return is_member

    def sample(self) -> Any:
        """Pick one of the listed values uniformly at random."""
        candidates = self.values
        return choice(candidates)
class SingletonSpace(Space):
    """Degenerate space whose only member is a single fixed value.

    >>> space = SingletonSpace(value=42)
    >>> space.contains(42)
    True
    >>> space.contains(43)
    False
    >>> space.sample()
    42
    """

    # The one and only member of the space.
    value: Any

    def contains(self, x: Any) -> bool:
        """Return the result of comparing ``x`` with the stored value."""
        matches = x == self.value
        return matches

    def sample(self) -> Any:
        """Return the stored value (there is nothing else to choose from)."""
        only_member = self.value
        return only_member
class IntegerSpace(Space):
    """Space for integer values with optional inclusive bounds.

    Booleans are not considered integers here, mirroring ``BooleanSpace``
    which does not consider integers to be booleans.

    >>> space = IntegerSpace(min_value=0, max_value=10)
    >>> space.contains(5)
    True
    >>> space.contains(-1)
    False
    >>> space.contains(11)
    False
    >>> space.contains(3.14)
    False
    >>> space.contains(True)
    False
    >>> sample = space.sample()
    >>> 0 <= sample <= 10
    True
    >>> isinstance(sample, int)
    True
    >>> IntegerSpace(min_value=500).sample() >= 500
    True
    """

    min_value: Optional[int] = None
    max_value: Optional[int] = None
    sample_default_min: int = -100  # Default min when no bounds provided
    sample_default_max: int = 100  # Default max when no bounds provided

    def contains(self, x: Any) -> bool:
        """Return True when ``x`` is an ``int`` (not a ``bool``) within the bounds."""
        # ``bool`` is a subclass of ``int``; reject it explicitly so that
        # True/False are only members of BooleanSpace.
        if isinstance(x, bool) or not isinstance(x, int):
            return False
        if self.min_value is not None and x < self.min_value:
            return False
        if self.max_value is not None and x > self.max_value:
            return False
        return True

    def sample(self) -> int:
        """Draw a random integer that satisfies the configured bounds.

        A missing bound falls back to ``sample_default_min`` /
        ``sample_default_max``.  When only one bound is configured, the
        fallback for the other side is widened if necessary so that the
        sampling range is never empty and the result is always a member
        of the space.

        Raises:
            ValueError: if both bounds are set and ``min_value > max_value``
                (raised by ``random.randint`` on the empty range).
        """
        low, high = self.min_value, self.max_value
        if low is None and high is None:
            low, high = self.sample_default_min, self.sample_default_max
        elif low is None:
            # Only an upper bound: keep the default floor unless it would
            # exceed the user's ceiling.
            low = min(self.sample_default_min, high)
        elif high is None:
            # Only a lower bound: keep the default ceiling unless it would
            # fall below the user's floor.
            high = max(self.sample_default_max, low)
        return randint(low, high)
class FloatSpace(Space):
    """Space for real-valued numbers with optional inclusive bounds.

    Plain integers are accepted as valid members; booleans are not.

    >>> space = FloatSpace(min_value=0.0, max_value=1.0)
    >>> space.contains(0.5)
    True
    >>> space.contains(-0.1)
    False
    >>> space.contains(1.1)
    False
    >>> space.contains("0.5")
    False
    >>> space.contains(float("nan"))
    False
    >>> sample = space.sample()
    >>> 0.0 <= sample <= 1.0
    True
    >>> isinstance(sample, float)
    True
    >>> FloatSpace(min_value=500.0).sample() >= 500.0
    True
    """

    min_value: Optional[float] = None
    max_value: Optional[float] = None
    sample_default_min: float = -100.0  # Default min when no bounds provided
    sample_default_max: float = 100.0  # Default max when no bounds provided

    def contains(self, x: Any) -> bool:
        """Return True when ``x`` is a non-boolean number within the bounds."""
        # Allow integers as valid floats, but not bool (a subclass of int),
        # so that True/False are only members of BooleanSpace.
        if isinstance(x, bool) or not isinstance(x, (int, float)):
            return False
        # The bound checks are phrased as "not >=" / "not <=" so that NaN,
        # for which every ordering comparison is False, fails any bound.
        if self.min_value is not None and not x >= self.min_value:
            return False
        if self.max_value is not None and not x <= self.max_value:
            return False
        return True

    def sample(self) -> float:
        """Draw a random float that satisfies the configured bounds.

        A missing bound falls back to ``sample_default_min`` /
        ``sample_default_max``.  When only one bound is configured, the
        fallback for the other side is widened if necessary so the result
        always respects the configured bound.
        """
        low, high = self.min_value, self.max_value
        if low is None and high is None:
            low, high = self.sample_default_min, self.sample_default_max
        elif low is None:
            # Only an upper bound: never start above the user's ceiling.
            low = min(self.sample_default_min, high)
        elif high is None:
            # Only a lower bound: never end below the user's floor.
            # (random.uniform does not complain about a reversed range, it
            # would silently return values below ``min_value``.)
            high = max(self.sample_default_max, low)
        return uniform(low, high)
class BooleanSpace(Space):
    """Two-element space holding exactly ``True`` and ``False``.

    >>> space = BooleanSpace()
    >>> space.contains(True)
    True
    >>> space.contains(False)
    True
    >>> space.contains(1)
    False
    >>> isinstance(space.sample(), bool)
    True
    """

    def contains(self, x: Any) -> bool:
        """Return True only for genuine ``bool`` objects."""
        # ``bool`` cannot be subclassed, so an exact type comparison is
        # equivalent to an isinstance check.
        return type(x) is bool

    def sample(self) -> bool:
        """Return ``True`` or ``False`` with equal probability."""
        outcomes = (True, False)
        return choice(outcomes)
class StringSpace(Space):
    """Space for string values with optional length and alphabet constraints.

    >>> space = StringSpace(min_length=2, max_length=4, allowed_chars=set('abc'))
    >>> space.contains('abc')
    True
    >>> space.contains('a')
    False
    >>> space.contains('abcde')
    False
    >>> space.contains('def')
    False
    >>> sample = space.sample()
    >>> 2 <= len(sample) <= 4
    True
    >>> all(c in 'abc' for c in sample)
    True
    >>> StringSpace(max_length=0).sample()
    ''
    >>> len(StringSpace(min_length=25).sample()) >= 25
    True
    """

    min_length: Optional[int] = None
    max_length: Optional[int] = None
    allowed_chars: Set[str] = Field(default_factory=lambda: set(string.printable))
    sample_default_min_length: int = 0  # Default min length when no bounds provided
    sample_default_max_length: int = 10  # Default max length when no bounds provided

    def contains(self, x: str) -> bool:
        """Return True when ``x`` is a ``str`` of permitted length and characters."""
        if not isinstance(x, str):
            return False
        if self.min_length is not None and len(x) < self.min_length:
            return False
        if self.max_length is not None and len(x) > self.max_length:
            return False
        return all(c in self.allowed_chars for c in x)

    def sample(self) -> str:
        """Build a random string that satisfies the configured constraints.

        A missing length bound falls back to the ``sample_default_*_length``
        fields.  Bounds are compared against ``None`` rather than tested
        for truthiness so that an explicit ``0`` is honoured, and a lone
        bound widens the opposite default when needed so the length range
        is never empty.

        Raises:
            ValueError: if a non-empty string is required but
                ``allowed_chars`` is empty, or if both length bounds are
                set and ``min_length > max_length``.
        """
        low, high = self.min_length, self.max_length
        if low is None and high is None:
            low = self.sample_default_min_length
            high = self.sample_default_max_length
        elif low is None:
            low = min(self.sample_default_min_length, high)
        elif high is None:
            high = max(self.sample_default_max_length, low)

        # Sort the alphabet: set iteration order for strings varies between
        # interpreter runs, which would make seeded sampling irreproducible.
        chars = sorted(self.allowed_chars)
        if not chars:
            # With no characters available the empty string is the only
            # candidate; it is valid only when the minimum length allows it.
            if low > 0:
                raise ValueError(
                    "Cannot sample a non-empty string from an empty allowed_chars set"
                )
            return ""

        length = randint(low, high)
        return "".join(choice(chars) for _ in range(length))
class DictSpace(Space):
    """Space of dictionaries with a fixed key set and one subspace per key.

    Keys may be of any hashable type, enum members included.

    >>> # For a fixed dictionary with enum keys:
    >>> from enum import Enum
    >>> class MouseButton(Enum):
    ...     LEFT = "left"
    ...     RIGHT = "right"
    >>> bool_space = BooleanSpace()
    >>> space = DictSpace(spaces={button: bool_space for button in MouseButton})
    >>> sample = space.sample()
    >>> all(isinstance(k, MouseButton) and isinstance(v, bool) for k, v in sample.items())
    True
    """

    # Maps every expected key to the space its value must belong to.
    spaces: Dict[Any, Space]

    def contains(self, x: Dict[Any, Any]) -> bool:
        """Return True when ``x`` is a dict with exactly the expected keys
        and every value lies in the subspace registered for its key."""
        if not isinstance(x, dict):
            return False
        # Any key present on one side only means the key sets differ.
        if set(x.keys()) ^ set(self.spaces.keys()):
            return False
        for key, subspace in self.spaces.items():
            if not subspace.contains(x[key]):
                return False
        return True

    def sample(self) -> Dict[Any, Any]:
        """Sample every subspace and assemble the results under their keys."""
        drawn: Dict[Any, Any] = {}
        for key, subspace in self.spaces.items():
            drawn[key] = subspace.sample()
        return drawn
class TupleSpace(Space):
    """Space of fixed-length tuples where each position has its own subspace.

    >>> subspaces = [IntegerSpace(min_value=0, max_value=10), IntegerSpace(min_value=0, max_value=10)]
    >>> space = TupleSpace(subspaces=subspaces)
    >>> x = space.sample()
    >>> isinstance(x, tuple) and len(x) == 2
    True
    >>> space.contains((5, 7))
    True
    >>> space.contains((5,))
    False
    """

    # One subspace per tuple position, in order.
    subspaces: List[Space]

    def contains(self, x: Any) -> bool:
        """Return True when ``x`` is a tuple of the right length whose items
        each belong to the subspace at the same position."""
        if not isinstance(x, tuple) or len(x) != len(self.subspaces):
            return False
        for position, subspace in enumerate(self.subspaces):
            if not subspace.contains(x[position]):
                return False
        return True

    def sample(self) -> tuple:
        """Sample every positional subspace, in order, and pack a tuple."""
        drawn = [subspace.sample() for subspace in self.subspaces]
        return tuple(drawn)
class ListSpace(Space):
    """Space for lists whose elements all belong to one subspace.

    >>> int_space = IntegerSpace(min_value=0, max_value=10)
    >>> space = ListSpace(subspace=int_space, min_length=2, max_length=4)
    >>> space.contains([1, 2, 3])
    True
    >>> space.contains([1])
    False
    >>> space.contains([1, 2, 3, 4, 5])
    False
    >>> space.contains([1, -1, 3])
    False
    >>> sample = space.sample()
    >>> 2 <= len(sample) <= 4
    True
    >>> all(0 <= x <= 10 for x in sample)
    True
    >>> ListSpace(subspace=int_space, max_length=0).sample()
    []
    >>> len(ListSpace(subspace=int_space, min_length=25).sample()) >= 25
    True
    """

    subspace: Space
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    sample_default_min_length: int = 0  # Default min length when no bounds provided
    sample_default_max_length: int = 10  # Default max length when no bounds provided

    def contains(self, x: List[Any]) -> bool:
        """Return True when ``x`` is a list of permitted length whose items
        all belong to ``subspace``."""
        if not isinstance(x, list):
            return False
        if self.min_length is not None and len(x) < self.min_length:
            return False
        if self.max_length is not None and len(x) > self.max_length:
            return False
        return all(self.subspace.contains(item) for item in x)

    def sample(self) -> List[Any]:
        """Build a random list that satisfies the configured length bounds.

        A missing bound falls back to the ``sample_default_*_length``
        fields.  Bounds are compared against ``None`` rather than tested
        for truthiness so that an explicit ``0`` is honoured, and a lone
        bound widens the opposite default when needed so the length range
        is never empty.

        Raises:
            ValueError: if both bounds are set and ``min_length > max_length``
                (raised by ``random.randint`` on the empty range).
        """
        low, high = self.min_length, self.max_length
        if low is None and high is None:
            low = self.sample_default_min_length
            high = self.sample_default_max_length
        elif low is None:
            low = min(self.sample_default_min_length, high)
        elif high is None:
            high = max(self.sample_default_max_length, low)
        length = randint(low, high)
        return [self.subspace.sample() for _ in range(length)]
class UnionSpace(Space):
    """Space whose members are the members of any of its component spaces.

    >>> int_space = IntegerSpace(min_value=0, max_value=10)
    >>> str_space = StringSpace(allowed_chars=set('abc'))
    >>> space = UnionSpace(spaces=[int_space, str_space])
    >>> space.contains(5)
    True
    >>> space.contains('abc')
    True
    >>> space.contains(-1)
    False
    >>> space.contains('def')
    False
    >>> sample = space.sample()
    >>> isinstance(sample, (int, str))
    True
    """

    # The alternatives; a value is accepted if any one of them accepts it.
    spaces: List[Space]

    def contains(self, x: Any) -> bool:
        """Return True as soon as one component space accepts ``x``."""
        for alternative in self.spaces:
            if alternative.contains(x):
                return True
        return False

    def sample(self) -> Any:
        """Pick a component space at random, then sample from it."""
        chosen = choice(self.spaces)
        return chosen.sample()
class StructuredSpace(Space):
    """Space for validating and sampling Pydantic model instances.

    A child space is derived for every public, non-``ClassVar`` annotation
    of ``model``.  A field annotated as ``Annotated[T, <Space instance>]``
    uses that space directly; otherwise the space is inferred from ``T``.

    >>> from pydantic import BaseModel
    >>> class Point(BaseModel):
    ...     x: int
    ...     y: int
    >>> space = StructuredSpace(model=Point)
    >>> space.contains(Point(x=1, y=2))
    True
    >>> space.contains(2)
    False
    >>> sample = space.sample()
    >>> isinstance(sample, Point)
    True
    """

    model: Type[BaseModel]
    # Field name -> child space; rebuilt per instance in __init__.
    _field_spaces: Dict[str, Space] = {}

    def __init__(self, **data):
        """Validate ``data`` as usual, then derive the per-field child spaces."""
        super().__init__(**data)
        self._field_spaces = {}
        self._analyze_model_fields()

    def _analyze_model_fields(self):
        """Recursively analyze model fields and create appropriate spaces.

        Raises:
            ValueError: if a field's annotation cannot be mapped to a space.
        """
        # Local import: ClassVar is not among the names imported at file level.
        from typing import ClassVar

        # include_extras=True is required to see Annotated[...] metadata;
        # by default get_type_hints strips it, which would make explicit
        # space annotations invisible.
        type_hints = get_type_hints(self.model, include_extras=True)
        for field_name, field_type in type_hints.items():
            if field_name.startswith("_"):
                continue
            # Class-level variables are not instance fields, so they have
            # no place in a space describing instances.
            if field_type is ClassVar or get_origin(field_type) is ClassVar:
                continue
            self._field_spaces[field_name] = self._create_space_for_type(field_type)

    def _normalize_type(self, type_: Any) -> tuple:
        """
        Convert a type annotation into a structured tuple representation for easier processing.

        The returned tuple has a "kind" tag as its first element:
        - ("optional", inner_type): for Optional types (e.g. float | None)
        - ("union", (arg1, arg2, ...)): for all other unions
        - ("list", element_type): for list[...] types
        - ("tuple", (elem1, elem2, ...)): for fixed-length tuples
        - ("var_tuple", element_type): for variable-length tuples (using Ellipsis) [not supported]
        - ("dict", key_type, value_type): for dict[...] types
        - ("literal", [literal1, literal2, ...]): for Literals
        - ("scalar", type): for all other types

        Raises:
            ValueError: for an unparameterized ``List`` / ``Dict`` annotation,
                whose element types are unknown.
        """
        origin = get_origin(type_)
        if origin in (Union, types.UnionType):
            args = get_args(type_)
            if len(args) == 2 and type(None) in args:
                # Optional type detected.
                non_none_type = next(t for t in args if t is not type(None))
                return ("optional", non_none_type)
            return ("union", args)
        elif origin is list:
            args = get_args(type_)
            if len(args) != 1:
                raise ValueError(f"List type must specify its element type, got {type_}")
            return ("list", args[0])
        elif origin is tuple:
            args = get_args(type_)
            if len(args) == 2 and args[1] is Ellipsis:
                return ("var_tuple", args[0])
            return ("tuple", args)
        elif origin is dict:
            args = get_args(type_)
            if len(args) != 2:
                raise ValueError(f"Dict type must specify key and value types, got {type_}")
            key_type, value_type = args
            return ("dict", key_type, value_type)
        elif origin is Literal:
            return ("literal", list(get_args(type_)))
        else:
            return ("scalar", type_)

    def _create_space_for_type(self, type_: Any) -> Space:
        """
        Create an appropriate space for a given type using its normalized representation.

        Supports:
        - Annotated[T, ...]: the first Space instance in the metadata is
          used as-is; without one, the space is derived from T
        - Optional types (e.g. float | None)
        - Unions (None is allowed as a member)
        - Lists, Tuples (fixed-length)
        - Dicts (with keys that are either str or Enum)
        - Literals
        - Nested Pydantic models
        - Basic types (str, int, float, bool)
        - Enums (returns a DiscreteSpace with all enum members)

        Raises:
            ValueError: for any annotation outside the list above.
        """
        # Annotated[...] exposes its extras as __metadata__ and the wrapped
        # type as __origin__.
        metadata = getattr(type_, "__metadata__", None)
        if metadata is not None:
            for extra in metadata:
                if isinstance(extra, Space):
                    return extra
            return self._create_space_for_type(type_.__origin__)

        normalized = self._normalize_type(type_)
        kind = normalized[0]
        if kind == "optional":
            inner_type = normalized[1]
            inner_space = self._create_space_for_type(inner_type)
            return UnionSpace(spaces=[inner_space, SingletonSpace(value=None)])
        elif kind == "union":
            args = normalized[1]
            return UnionSpace(spaces=[self._create_space_for_type(t) for t in args])
        elif kind == "list":
            element_type = normalized[1]
            return ListSpace(subspace=self._create_space_for_type(element_type))
        elif kind == "tuple":
            subspaces = [self._create_space_for_type(t) for t in normalized[1]]
            return TupleSpace(subspaces=subspaces)
        elif kind == "var_tuple":
            raise ValueError("Variable-length tuples are not supported.")
        elif kind == "dict":
            key_type, value_type = normalized[1], normalized[2]
            if key_type is str:
                # NOTE(review): the key set of a str-keyed dict is unknown,
                # so this yields a space that only contains the empty dict
                # and ignores value_type — confirm this is intended.
                return DictSpace(spaces={})
            elif isinstance(key_type, type) and issubclass(key_type, Enum):
                keys = list(key_type)
                spaces_dict = {
                    key: self._create_space_for_type(value_type) for key in keys
                }
                return DictSpace(spaces=spaces_dict)
            else:
                raise ValueError(f"Dict keys must be strings or Enum, got {key_type}")
        elif kind == "literal":
            return DiscreteSpace(values=normalized[1])
        elif kind == "scalar":
            base = normalized[1]
            if base is type(None):
                # Reached for None inside a union of three or more members,
                # e.g. ``int | str | None``.
                return SingletonSpace(value=None)
            if isinstance(base, type) and issubclass(base, BaseModel):
                return StructuredSpace(model=base)
            if isinstance(base, type) and issubclass(base, Enum):
                return DiscreteSpace(values=list(base))
            if base is str:
                return StringSpace()
            if base is int:
                return IntegerSpace()
            if base is float:
                return FloatSpace()
            if base is bool:
                return BooleanSpace()
            raise ValueError(f"Unsupported type: {base}")
        else:
            raise ValueError(f"Unsupported type kind: {kind}")

    def contains(self, x: Any) -> bool:
        """Return True when ``x`` is a ``model`` instance whose every analyzed
        field value belongs to that field's child space."""
        # Make sure x is actually an instance of the model
        if not isinstance(x, self.model):
            return False
        # Strictly check each field with its corresponding child space
        for field_name, space in self._field_spaces.items():
            value = getattr(x, field_name)
            if not space.contains(value):
                return False
        return True

    def sample(self) -> BaseModel:
        """Sample every field's child space and build a ``model`` instance
        from the results."""
        sample_data = {
            field_name: space.sample()
            for field_name, space in self._field_spaces.items()
        }
        return self.model(**sample_data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment