Skip to content

Instantly share code, notes, and snippets.

@kwsp
Created June 24, 2024 18:33
Show Gist options
  • Save kwsp/69440134a4bee0dcbc05cf97e1446bfd to your computer and use it in GitHub Desktop.
Save kwsp/69440134a4bee0dcbc05cf97e1446bfd to your computer and use it in GitHub Desktop.
InitMixin
"""
typed_mat_mixin.py provides a mixin `InitMixin` that tries
to generate a `.init` method to a typed Dataclass to automatically
deserialize .mat files exported from MATLAB to a Python Dataclass.
Supported types:
- strings
- int, float, np.uint8, and most np typedefs
- np.ndarray (1D and 2D tested)
- Type arrays as np.array[ndims, dtype]
- For example, a 2D array of type uint8 would be `np.ndarray[2, np.uint8]`
Usage:
```python
from dataclasses import dataclass
import numpy as np
import scipy.io as sio
from typed_mat_mixin import InitMixin
@dataclass
class InnerObject(InitMixin):
var3: int
arr1: np.ndarray[1, np.float64]
@dataclass
class Sequence(InitMixin):
var1: int
var2: float
path "something.mat"
seq = Sequence.init(sio.loadmat(path))
```
"""
# %%
import inspect
import numbers
import types
from typing import (
    Any,
    Callable,
    GenericAlias,
    Optional,
    TypeGuard,
    TypeVar,
    Union,
    get_args,
    get_origin,
)

import numpy as np
# %%
def npdtype(cls: type | None = None, align=False):
"""
Converts an annotated Python class into a numpy.dtype object.
https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
Propose to support using numpy.dtype as a class decorator.
The type annotation can be
* A type compatible with np.dtype, such as int, np.int, np.float64
* A 2-tuple of (type, shape)
"""
def wrap(cls):
anno = inspect.get_annotations(cls)
dtype = []
for attrname, attrtype in anno.items():
if isinstance(attrtype, tuple):
dtype.append((attrname, *attrtype))
else:
dtype.append((attrname, attrtype))
return np.dtype(dtype, align=align)
if cls is None:
return wrap
return wrap(cls)
# %%
# Until PEP 646 (https://peps.python.org/pep-0646/) is fully supported, array
# dimensions cannot be typed explicitly. As a workaround, the helpers below
# inspect `np.ndarray[...]` annotations and build custom init functions for
# arrays of different shapes and dtypes.
T = TypeVar("T", bound=object)

# Generic type variable used throughout the module.
_Type = TypeVar("_Type")
# Signature of an init function: takes loaded data, returns the converted value.
_InitFuncT = Callable[[Any], _Type]
# Generated init functions, memoized per annotated type by `get_init_func`.
_init_func_cache: dict[_Type, _InitFuncT] = dict()
# %%
def _get_optional_type(tp: type) -> Optional[type]:
"""
If tp is in {Optional[T], T | None, Union[T, None]}, return T
else return None
"""
args = get_args(tp)
match args:
case type(), types.NoneType:
return args[0]
case types.NoneType, type():
return args[1]
return None
def _get_optional_type_test():
assert _get_optional_type(Optional[int]) is int
assert _get_optional_type(Optional[str]) is str
assert _get_optional_type(str | None) is str
assert _get_optional_type(None | str) is str
assert _get_optional_type(str) is None
# %%
def _is_str(inp) -> bool:
    """
    Report whether `inp` is string-like: a ``str``, a ``bytes``, or an ndarray
    whose dtype is a character type (``np.str_`` / ``np.bytes_``).
    """
    if isinstance(inp, np.ndarray):
        # Character arrays (e.g. dtype '<U5' or 'S5') count regardless of shape.
        return issubclass(inp.dtype.type, np.character)
    return isinstance(inp, (bytes, str))
def _is_str_test():
    """Self-check for `_is_str`."""
    # string-like inputs
    assert _is_str("Hello")
    assert _is_str(b"Hello")
    assert _is_str(np.array("Hello"))
    assert _is_str(np.array(["Hello"]))
    # non-string inputs
    assert not _is_str(1)
    assert not _is_str([1, 2])
    # np.array, not the np.ndarray constructor: np.ndarray([1, 2]) would build
    # an uninitialized float array of *shape* (1, 2).
    assert not _is_str(np.array([1, 2]))
def _type_is_str(tp) -> bool:
    """Return True when `tp` is exactly the ``str`` or ``bytes`` type (subclasses excluded)."""
    string_types = (str, bytes)
    return tp in string_types
def _type_is_str_test():
    """Self-check for `_type_is_str`: both builtin string types are recognized."""
    for string_type in (str, bytes):
        assert _type_is_str(string_type)
# %%
def _is_sequence_not_str(inp) -> bool:
    """
    Return True if `inp` is iterable and not string-like (see `_is_str`).

    A 0-d ndarray defines ``__iter__`` but raises TypeError when iterated, so
    it is treated as a scalar rather than a sequence.
    """
    if isinstance(inp, np.ndarray) and inp.ndim == 0:
        return False
    return hasattr(inp, "__iter__") and not _is_str(inp)


def _is_sequence_not_str_test():
    """Self-check for `_is_sequence_not_str`."""
    inp = [1, 2, 3, 4]
    assert _is_sequence_not_str(inp)
    inp = np.array([1, 2, 3, 4])
    assert _is_sequence_not_str(inp)
    inp = np.array([[1, 2, 3, 4]])
    assert _is_sequence_not_str(inp)
    # 0-d numeric arrays cannot be iterated, so they are not sequences
    inp = np.array(1.0)
    assert not _is_sequence_not_str(inp)
    inp = "hello"
    assert not _is_sequence_not_str(inp)
    inp = b"hello"
    assert not _is_sequence_not_str(inp)
    inp = np.array("hello")
    assert not _is_sequence_not_str(inp)
    inp = np.array(["hello"])
    assert not _is_sequence_not_str(inp)
    inp = np.array(b"hello")
    assert not _is_sequence_not_str(inp)
    inp = np.array([b"hello"])
    assert not _is_sequence_not_str(inp)
# %%
def _get_scalar_init_func(_type: type) -> Callable[[Any], Any]:
    """
    Build an init function that converts loaded data to the scalar type
    `_type` (e.g. int, float, np.uint8).

    MATLAB scalars arrive wrapped, e.g. np.array([[1.0]]) or [[1.0]]. The
    returned function unwraps containers until a bare value remains, then
    returns ``_type(value)``:
    * an ndarray must hold exactly one element, which is extracted with
      ``.item()``; object arrays may wrap further containers, so unwrapping
      continues through them;
    * other non-string sequences (lists, ...) contribute their first element.

    Raises ValueError if an ndarray met while unwrapping does not have size 1.
    """

    def _init(inp):
        while True:
            if isinstance(inp, np.ndarray):
                if inp.size != 1:
                    raise ValueError(f"Input array must be size 1, got {inp!r}")
                holds_objects = inp.dtype == np.dtype("O")
                inp = inp.item()
                if not holds_objects:
                    # .item() of a non-object array is already a Python scalar.
                    return _type(inp)
            elif _is_sequence_not_str(inp):
                # handle [[ 1.0 ]]
                inp = inp[0]
            else:
                return _type(inp)

    return _init
def _get_string_init_func(_type: type) -> Callable[[Any], Any]:
    """
    Build an init function that converts loaded char data to `_type`
    (``str`` or ``bytes``).

    Accepted inputs include "hello", ["hello"], np.array(["hello"]) and
    np.array([["hello"]]); single-element non-string containers are unwrapped
    recursively. An empty ndarray (of any dtype) yields ``_type()``, i.e.
    "" or b"". Presumably this is how an empty MATLAB char array arrives from
    scipy.io.loadmat; confirm against real data.

    Raises ValueError if a non-string container does not hold exactly one item.
    """

    def _init(inp):
        if isinstance(inp, np.ndarray) and inp.size == 0:
            # str() of an empty array would produce the literal text "[]".
            return _type()
        if _is_sequence_not_str(inp):
            if len(inp) != 1:
                raise ValueError(f"Expected a single string value, got {inp!r}")
            return _init(inp[0])
        # squeeze to handle np.array(["hello"]) -> np.array("hello")
        if isinstance(inp, np.ndarray):
            inp = inp.squeeze()
        return _type(inp)

    return _init
def _is_tuple_of_int(obj) -> TypeGuard[tuple[int, ...]]:
"Check the given object is a tuple of int, e.g. (2, 3, 4)"
assert isinstance(obj, tuple)
return all(isinstance(a, int) for a in obj)
import builtins
def _get_ndarray_init_func(_type: type[np.ndarray]):
"""
Take a specialization of a np.ndarray generic type
and return a init function that converts an input to _type
https://numpy.org/devdocs/reference/generated/numpy.ndarray.__class_getitem__.html
Check the *number* of dimensions and the dtype.
_type examples:
np.ndarray[(int, int), float] <- 2D array with any shape
np.ndarray[tuple[int, int], np.uint8] <- 2D array with any shape
np.ndarray[(2, 3), int] <- 2D array with shape (2, 3)
Note: `builtins.tuple`` supports `[]` since 3.9
"""
assert _type.__origin__ == np.ndarray
assert len(_type.__args__) == 2
# annotated shape and type
_shape_t, _dtype = _type.__args__
n_dims = 0
shape = None
match _shape_t:
case Any:
pass
case int():
n_dims = 1
shape = _shape_t
case tuple() if _is_tuple_of_int(_shape_t):
# np.ndarray[(2, 3), dtype]
# One dimension can be -1
n_dims = len(_shape_t)
shape = _shape_t
case GenericAlias(__origin__=builtins.tuple):
# np.ndarray[tuple[int, int], dtype]
n_dims = len(_shape_t.__args__)
case _:
raise ValueError("Failed to parse np.ndarray dimension: ", _shape_t)
# specialize for dim == 1
if n_dims == 1:
# This is specific to .mat files
# Mat files saves 1D arrays as 2D
# not even a 2D array, but rather array(array([...], dtype), object)
# when loaded with scipy.io.loadmat
def _init(inp):
return np.asarray(inp[0], dtype=_dtype)
return _init
# generic version
def _init(inp):
try:
return np.asarray(inp, dtype=_dtype).reshape(shape)
except TypeError as e:
# Try this heinous case
return np.asarray([x for x in inp[0]], dtype=_dtype).reshape(shape)
return _init
def _get_list_init_func(_type: type[_Type]):
    """
    Build an init function for ``list[T]`` annotations (e.g. ``list[int]``).

    The element type ``T`` is resolved once through ``get_init_func``; the
    returned function applies it to every item of its input.
    """
    assert _type.__origin__ == list
    element_init = get_init_func(_type.__args__[0])

    def _init(inp) -> _Type:
        return list(map(element_init, inp))

    return _init
def _noop(inp):
    """Identity init function: hand back `inp` unchanged (used for Any/object)."""
    unchanged = inp
    return unchanged
def get_init_func(_type) -> _InitFuncT:
    """
    Get the init function for `_type`.

    Resolution order:
    1. a type with its own ``init`` attribute (e.g. an ``InitMixin`` subclass)
       -> ``_type.init``;
    2. ``Any`` / ``object`` -> ``_noop`` (value passed through unchanged);
    3. a function generated earlier and stored in ``_init_func_cache``;
    4. ``np.ndarray[...]`` / ``list[...]`` specializations. The annotation has
       to be inspected (e.g. dims and dtype for np.ndarray), so a wrapper
       function is generated rather than looked up in a plain dict;
    5. ``numbers.Number`` subclasses (int, float, np.uint8, ...), then
       ``str`` / ``bytes``.

    Raises NotImplementedError for every other type, including ``tuple[...]``,
    other generic aliases, and non-class typing constructs such as ``Union``.
    """

    def c(func):
        # Memoize the generated init function under `_type`.
        _init_func_cache[_type] = func
        return func

    # Use the `init` classmethod if available
    if hasattr(_type, "init"):
        return _type.init
    # Any defaults to noop
    if _type in (Any, object):
        return _noop
    if func := _init_func_cache.get(_type):
        return func

    if isinstance(_type, GenericAlias):
        # _type is a GenericAlias specialization
        origin = _type.__origin__
        if origin is np.ndarray:
            return c(_get_ndarray_init_func(_type))
        if origin is list:
            return c(_get_list_init_func(_type))
        # tuple[...], dict[...], ... are unsupported. Raise here rather than
        # falling through to issubclass(), which rejects non-class arguments
        # with TypeError.
        raise NotImplementedError(f"get_init_func not implemented for type {_type}")

    # Scalar types {int, float, np.number}
    # note: np.number includes all {np.uint8, np.float64, ...}
    # https://numpy.org/doc/stable/reference/arrays.scalars.html
    # numbers.Number includes np.number.
    # isinstance(_type, type) guards issubclass against typing constructs
    # such as Union, which are not classes.
    if isinstance(_type, type) and issubclass(_type, numbers.Number):
        return c(_get_scalar_init_func(_type))
    # string types
    if _type_is_str(_type):
        return c(_get_string_init_func(_type))

    raise NotImplementedError(f"get_init_func not implemented for type {_type}")
# %%
def get_init_func_tests():
    """Self-checks for `get_init_func` on the annotation forms it supports."""
    # scalars and strings wrapped the way loadmat wraps them
    assert get_init_func(int)([[1.0]]) == 1
    assert get_init_func(float)([[1.0]]) == 1.0
    assert get_init_func(str)(["hello"]) == "hello"

    # 1d array (integer shape), uint8
    tp = np.ndarray[1, np.uint8]
    a = get_init_func(tp)([[2, 2, 2]])
    gt = np.array([2, 2, 2], np.uint8)
    assert a.shape == (3,)
    assert np.allclose(a, gt)
    assert a.dtype == gt.dtype

    # 1d array (tuple shape), uint8
    tp = np.ndarray[(1,), np.uint8]
    a = get_init_func(tp)([[2, 2, 2]])
    gt = np.array([2, 2, 2], np.uint8)
    assert a.shape == (3,)
    assert np.allclose(a, gt)
    assert a.dtype == gt.dtype

    # 1d array (tuple generic), uint8
    tp = np.ndarray[tuple[int], np.uint8]
    a = get_init_func(tp)([[2, 2, 2]])
    gt = np.array([2, 2, 2], np.uint8)
    assert a.shape == (3,)
    assert np.allclose(a, gt)
    assert a.dtype == gt.dtype

    # integer shape with a float dtype (read through the 1d path)
    tp = np.ndarray[3, float]
    a = get_init_func(tp)([[2, 2, 2]])
    gt = np.array([2, 2, 2], float)
    assert np.allclose(a, gt)
    assert a.dtype == gt.dtype

    # 2d array with fixed shape
    tp = np.ndarray[(2, 3), np.float64]
    a = get_init_func(tp)([[2, 2, 2], [1, 2, 3]])
    gt = np.array([[2, 2, 2], [1, 2, 3]], np.float64)
    assert a.shape == (2, 3)
    assert np.allclose(a, gt)
    assert a.dtype == gt.dtype

    # 2d array with variable shape
    tp = np.ndarray[(2, -1), np.float64]
    a = get_init_func(tp)([[2, 2, 2], [1, 2, 3]])
    gt = np.array([[2, 2, 2], [1, 2, 3]], np.float64)
    assert a.shape == (2, 3)
    assert np.allclose(a, gt)
    assert a.dtype == gt.dtype

    # strings
    a = get_init_func(str)(["hello"])
    gt = "hello"
    assert a == gt
    a = get_init_func(str)(np.array(["hello"]))
    gt = "hello"
    assert a == gt
    a = get_init_func(str)(np.array([["hello"]]))
    gt = "hello"
    assert a == gt
# %%
# Type variable for `InitMixin.init`, so that `SubClass.init(...)` is typed as
# returning `SubClass`. The bound is a string forward reference because
# `InitMixin` is defined below.
StructT = TypeVar("StructT", bound="InitMixin")
class InitMixin:
    """
    Init mixin to initialize dataclass attributes from
    an object loaded with scipy.io.loadmat.

    Subclasses (typically dataclasses) gain an ``init`` classmethod that reads
    each annotated attribute from the loaded data and converts it with
    ``get_init_func``.
    """

    @classmethod
    def init(cls: type[StructT], inp: object) -> StructT:
        """
        Build an instance of ``cls`` from ``inp``.

        ``inp`` may be the mapping returned by ``scipy.io.loadmat`` or an
        ndarray holding a MATLAB struct/cell (its first element is used).
        Each annotated attribute of ``cls`` is looked up by name. A single
        trailing ``_`` is stripped first, so e.g. ``type_`` reads field
        ``type``. The value is then converted with ``get_init_func``.

        Attributes annotated ``Optional[T]`` / ``T | None`` / ``Any`` are set to
        ``None`` when missing. A missing non-optional attribute raises KeyError.
        If converting an attribute fails, a message naming the class and
        attribute is printed and the original exception is re-raised.
        """
        # For initializing a matlab struct/cell read from scipy.io.loadmat,
        # it's always one index in
        if isinstance(inp, np.ndarray):
            inp = inp[0]

        data = {}
        # alternatively, use dataclasses.fields
        # https://docs.python.org/3/library/dataclasses.html#dataclasses.fields
        anno = inspect.get_annotations(cls)
        for attr_name, attr_type in anno.items():
            # when attr_name is a python reserved keyword
            # e.g. "type" and "class", a "_" suffix is used
            # so the name can be used in a dataclass
            _attr_name = attr_name.removesuffix("_")

            # Check if attr_type is an optional type
            _attr_type = attr_type
            is_optional = False
            if _tp := _get_optional_type(attr_type):
                _attr_type = _tp
                is_optional = True
            # A field annotated `Any` is optional as well
            if attr_type is Any:
                is_optional = True

            try:
                inp_data = inp[_attr_name]
            except (KeyError, ValueError) as e:
                # A dict (the top level of loadmat output) raises KeyError. A
                # numpy structured array raises ValueError for an unknown field
                # name.
                if is_optional:
                    data[attr_name] = None
                    continue
                raise KeyError(
                    f"Key {_attr_name} not found when initializing {cls}"
                ) from e

            # Unwrap a length-1 object array (one level of loadmat wrapping).
            # The ndim check avoids len() on a 0-d array, which raises TypeError.
            if (
                isinstance(inp_data, np.ndarray)
                and inp_data.dtype == np.dtype("O")
                and inp_data.ndim > 0
                and len(inp_data) == 1
            ):
                inp_data = inp_data[0]

            try:
                # If the annotated type has a `.init` method, use it to initialize
                # the object. Otherwise, call `get_init_func` to generate an init
                # func for the object
                val = get_init_func(_attr_type)(inp_data)
            except Exception:
                print(
                    f'Error initializing "{cls}" on attribute "{attr_name}" of type "{_attr_type}"'
                )
                raise
            data[attr_name] = val
        return cls(**data)
# %%
if __name__ == "__main__":
    # Self-checks: run when this file is executed directly (e.g.
    # `python typed_mat_mixin.py`), not as a side effect of importing it.
    _get_optional_type_test()
    _is_str_test()
    _type_is_str_test()
    _is_sequence_not_str_test()
    get_init_func_tests()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment