Created
June 24, 2024 18:33
-
-
Save kwsp/69440134a4bee0dcbc05cf97e1446bfd to your computer and use it in GitHub Desktop.
InitMixin
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
typed_mat_mixin.py provides a mixin `InitMixin` that adds an `.init`
classmethod to a typed dataclass, so that .mat files exported from MATLAB
(and loaded with scipy.io.loadmat) are deserialized into that dataclass.
Supported types: | |
- strings | |
- int, float, np.uint8, and most np typedefs | |
- np.ndarray (1D and 2D tested) | |
- Type arrays as np.ndarray[shape, dtype], where shape is an int, a tuple of ints such as (2, 3), or tuple[int, ...]
- For example, a 2D array of type uint8 with any shape would be `np.ndarray[tuple[int, int], np.uint8]`
Usage: | |
```python | |
from dataclasses import dataclass | |
import numpy as np | |
import scipy.io as sio | |
from typed_mat_mixin import InitMixin | |
@dataclass | |
class InnerObject(InitMixin): | |
var3: int | |
arr1: np.ndarray[1, np.float64] | |
@dataclass | |
class Sequence(InitMixin): | |
var1: int | |
var2: float | |
path = "something.mat"
seq = Sequence.init(sio.loadmat(path)) | |
``` | |
""" | |
# %% | |
from typing import Any, Callable, TypeVar, GenericAlias, Optional, get_args, TypeGuard | |
import types | |
import inspect | |
import numbers | |
import numpy as np | |
# %% | |
def npdtype(cls: type | None = None, align=False): | |
""" | |
Converts an annotated Python class into a numpy.dtype object. | |
https://numpy.org/doc/stable/reference/generated/numpy.dtype.html | |
Propose to support using numpy.dtype as a class decorator. | |
The type annotation can be | |
* A type compatible with np.dtype, such as int, np.int, np.float64 | |
* A 2-tuple of (type, shape) | |
""" | |
def wrap(cls): | |
anno = inspect.get_annotations(cls) | |
dtype = [] | |
for attrname, attrtype in anno.items(): | |
if isinstance(attrtype, tuple): | |
dtype.append((attrname, *attrtype)) | |
else: | |
dtype.append((attrname, attrtype)) | |
return np.dtype(dtype, align=align) | |
if cls is None: | |
return wrap | |
return wrap(cls) | |
# %% | |
# Until https://peps.python.org/pep-0646/ gets full support, we cannot explicitly | |
# type array dimensions. | |
# Here's a hack to make declarations simpler and add custom init functions for arrays
# of different shapes and types
# Generic type variables and aliases shared by the init-function factories.
T = TypeVar("T", bound=object)
_Type = TypeVar("_Type")  # generic type used throughout module
_InitFuncT = Callable[[Any], _Type]  # generic init function type

# Maps an annotated type to the init function generated for it.
_init_func_cache: dict[_Type, _InitFuncT] = dict()
# %% | |
def _get_optional_type(tp: type) -> Optional[type]: | |
""" | |
If tp is in {Optional[T], T | None, Union[T, None]}, return T | |
else return None | |
""" | |
args = get_args(tp) | |
match args: | |
case type(), types.NoneType: | |
return args[0] | |
case types.NoneType, type(): | |
return args[1] | |
return None | |
def _get_optional_type_test():
    """Self-check for `_get_optional_type`; raises AssertionError on failure."""
    optional_cases = (
        (Optional[int], int),
        (Optional[str], str),
        (str | None, str),
        (None | str, str),
    )
    for annotation, inner in optional_cases:
        assert _get_optional_type(annotation) is inner

    # a plain, non-optional annotation yields None
    assert _get_optional_type(str) is None
# %% | |
def _is_str(inp) -> bool:
    """
    Tell whether `inp` is string-like: a str, a bytes, or an np.ndarray whose
    dtype's scalar type is a subclass of np.character.
    """
    if isinstance(inp, np.ndarray):
        return issubclass(inp.dtype.type, np.character)
    return isinstance(inp, (bytes, str))
def _is_str_test():
    """Self-check for `_is_str`; raises AssertionError on the first failure."""
    string_like = ("Hello", b"Hello", np.array("Hello"), np.array(["Hello"]))
    for value in string_like:
        assert _is_str(value)

    not_string_like = (1, [1, 2], np.ndarray([1, 2]))
    for value in not_string_like:
        assert not _is_str(value)
def _type_is_str(tp) -> bool:
    """Tell whether the type `tp` is one of the string types (str or bytes)."""
    string_types = (str, bytes)
    return tp in string_types
def _type_is_str_test():
    """Self-check for `_type_is_str`; raises AssertionError on failure."""
    for string_type in (str, bytes):
        assert _type_is_str(string_type)
# %% | |
def _is_sequence_not_str(inp) -> bool:
    """
    Tell whether `inp` exposes `__iter__` and is not string-like
    (as judged by `_is_str`).
    """
    if not hasattr(inp, "__iter__"):
        return False
    return not _is_str(inp)
def _is_sequence_not_str_test():
    """Self-check for `_is_sequence_not_str`; raises AssertionError on failure."""
    # lists and numeric arrays count as sequences
    for make_input in (
        lambda: [1, 2, 3, 4],
        lambda: np.array([1, 2, 3, 4]),
        lambda: np.array([[1, 2, 3, 4]]),
    ):
        assert _is_sequence_not_str(make_input())

    # str / bytes, bare or wrapped in an array, do not
    for make_input in (
        lambda: "hello",
        lambda: b"hello",
        lambda: np.array("hello"),
        lambda: np.array(["hello"]),
        lambda: np.array(b"hello"),
        lambda: np.array([b"hello"]),
    ):
        assert not _is_sequence_not_str(make_input())
# %% | |
def _get_scalar_init_func(_type: "type[_Type]") -> "Callable[[Any], _Type]":
    """
    Return an init function that converts its input to the scalar `_type`.

    The returned function accepts
    * an np.ndarray holding exactly one element, e.g. np.array([[1.0]])
    * a (possibly nested) sequence, which is unwrapped by repeatedly taking
      its first element, e.g. [[1.0]]
    * anything else, which is passed to `_type` directly

    Raises:
        ValueError: if the input is an np.ndarray whose size is not 1.
    """

    def _init(inp):
        if isinstance(inp, np.ndarray):
            # handle cases of np.array([[ 1.0 ]])
            if inp.size != 1:
                raise ValueError(
                    f"Input array must be size 1, got size {inp.size}: {inp!r}"
                )
            inp = inp.item()
        else:
            # handle [[ 1.0 ]]
            while _is_sequence_not_str(inp):
                inp = inp[0]
        return _type(inp)

    return _init
def _get_string_init_func(_type: _Type) -> _Type:
    """
    Return an init function that converts its input to the string type
    `_type`.

    Non-string sequences are unwrapped one level at a time and must hold
    exactly one element at every level (checked with `assert`). An
    np.ndarray left after unwrapping is squeezed before conversion.
    """

    def _init(inp):
        # peel off single-element wrappers such as [["hello"]]
        while _is_sequence_not_str(inp):
            assert len(inp) == 1
            inp = inp[0]

        # squeeze to handle np.array(["hello"])
        value = inp.squeeze() if isinstance(inp, np.ndarray) else inp
        return _type(value)

    return _init
def _is_tuple_of_int(obj) -> TypeGuard[tuple[int, ...]]: | |
"Check the given object is a tuple of int, e.g. (2, 3, 4)" | |
assert isinstance(obj, tuple) | |
return all(isinstance(a, int) for a in obj) | |
import builtins | |
def _get_ndarray_init_func(_type: type[np.ndarray]): | |
""" | |
Take a specialization of a np.ndarray generic type | |
and return a init function that converts an input to _type | |
https://numpy.org/devdocs/reference/generated/numpy.ndarray.__class_getitem__.html | |
Check the *number* of dimensions and the dtype. | |
_type examples: | |
np.ndarray[(int, int), float] <- 2D array with any shape | |
np.ndarray[tuple[int, int], np.uint8] <- 2D array with any shape | |
np.ndarray[(2, 3), int] <- 2D array with shape (2, 3) | |
Note: `builtins.tuple`` supports `[]` since 3.9 | |
""" | |
assert _type.__origin__ == np.ndarray | |
assert len(_type.__args__) == 2 | |
# annotated shape and type | |
_shape_t, _dtype = _type.__args__ | |
n_dims = 0 | |
shape = None | |
match _shape_t: | |
case Any: | |
pass | |
case int(): | |
n_dims = 1 | |
shape = _shape_t | |
case tuple() if _is_tuple_of_int(_shape_t): | |
# np.ndarray[(2, 3), dtype] | |
# One dimension can be -1 | |
n_dims = len(_shape_t) | |
shape = _shape_t | |
case GenericAlias(__origin__=builtins.tuple): | |
# np.ndarray[tuple[int, int], dtype] | |
n_dims = len(_shape_t.__args__) | |
case _: | |
raise ValueError("Failed to parse np.ndarray dimension: ", _shape_t) | |
# specialize for dim == 1 | |
if n_dims == 1: | |
# This is specific to .mat files | |
# Mat files saves 1D arrays as 2D | |
# not even a 2D array, but rather array(array([...], dtype), object) | |
# when loaded with scipy.io.loadmat | |
def _init(inp): | |
return np.asarray(inp[0], dtype=_dtype) | |
return _init | |
# generic version | |
def _init(inp): | |
try: | |
return np.asarray(inp, dtype=_dtype).reshape(shape) | |
except TypeError as e: | |
# Try this heinous case | |
return np.asarray([x for x in inp[0]], dtype=_dtype).reshape(shape) | |
return _init | |
def _get_list_init_func(_type: type[_Type]):
    """
    Support list[int] etc.

    Returns an init function that converts every element of its input with
    the init function of the list's element type, collecting the results in
    a list.
    """
    assert _type.__origin__ == list
    _elem_init = get_init_func(_type.__args__[0])

    def _init(inp) -> _Type:
        return list(map(_elem_init, inp))

    return _init
def _noop(inp):
    """Identity init function: hand the input back untouched."""
    return inp
def get_init_func(_type) -> "_InitFuncT":
    """
    Get the init function for the annotated type `_type`.

    For specializations of GenericAlias, we need to inspect the type annotation
    (e.g. for np.ndarray, we need to inspect for dims and dtype), so we have to
    use a wrapper function (rather than a simple dict lookup).

    Resolution order:
    1. a type exposing an `init` attribute -> that attribute
    2. {Any, object} -> noop
    3. a previously generated function from the module cache
    4. np.ndarray[...] / list[...] specializations
    5. subclasses of numbers.Number -> scalar init
    6. str / bytes -> string init

    Raises:
        NotImplementedError: for tuple[...] and for any other unknown type,
            including annotations that are not classes at all.
    """

    def c(func):
        # remember the generated function so it is built once per type
        _init_func_cache[_type] = func
        return func

    # Use the `init` classmethod if available
    if hasattr(_type, "init"):
        return _type.init

    # Any defaults to noop
    if _type in (Any, object):
        return _noop

    if func := _init_func_cache.get(_type):
        return func

    if isinstance(_type, GenericAlias):
        # _type is a GenericAlias specialization
        origin = _type.__origin__
        if origin is np.ndarray:
            return c(_get_ndarray_init_func(_type))
        if origin is list:
            return c(_get_list_init_func(_type))
        if origin is tuple:
            raise NotImplementedError(
                f"get_init_func not implemented for type {_type}"
            )

    # Scalar types {int, float, np.number}
    # note: np.number includes all {np.uint8, np.float64, ...}
    # https://numpy.org/doc/stable/reference/arrays.scalars.html
    #
    # numbers.Number includes np.numbers.
    try:
        is_number = issubclass(_type, numbers.Number)
    except TypeError:
        # `_type` is not a class (e.g. dict[str, int]); fall through so the
        # caller gets NotImplementedError instead of an issubclass TypeError
        is_number = False
    if is_number:
        return c(_get_scalar_init_func(_type))

    # string types
    if _type in (str, bytes):
        return c(_get_string_init_func(_type))

    raise NotImplementedError(f"get_init_func not implemented for type {_type}")
# %% | |
def get_init_func_tests():
    """Self-check for `get_init_func`; raises AssertionError on failure."""
    # scalars and strings wrapped in (nested) lists
    for tp, raw, expected in (
        (int, [[1.0]], 1),
        (float, [[1.0]], 1.0),
        (str, ["hello"], "hello"),
    ):
        assert get_init_func(tp)(raw) == expected

    def check_array(tp, raw, expected, shape=None):
        # convert `raw` with the init function for `tp` and compare
        got = get_init_func(tp)(raw)
        if shape is not None:
            assert got.shape == shape
        assert np.allclose(got, expected)
        assert got.dtype == expected.dtype

    row = [[2, 2, 2]]

    # 1d array (tuple generic), uint8
    check_array(
        np.ndarray[tuple[int], np.uint8], row, np.array([2, 2, 2], np.uint8), (3,)
    )
    # 1d array (tuple shape), uint8
    check_array(
        np.ndarray[(1,), np.uint8], row, np.array([2, 2, 2], np.uint8), (3,)
    )
    # 1d array (tuple generic), uint8
    check_array(
        np.ndarray[tuple[int], np.uint8], row, np.array([2, 2, 2], np.uint8), (3,)
    )
    # 1d array (integer shape), float; shape is not checked here
    check_array(np.ndarray[3, float], row, np.array([2, 2, 2], float))

    grid = [[2, 2, 2], [1, 2, 3]]

    # 2d array with fixed shape
    check_array(
        np.ndarray[(2, 3), np.float64], grid, np.array(grid, np.float64), (2, 3)
    )
    # 2d array with variable shape
    check_array(
        np.ndarray[(2, -1), np.float64], grid, np.array(grid, np.float64), (2, 3)
    )

    # strings
    for raw in (["hello"], np.array(["hello"]), np.array([["hello"]])):
        assert get_init_func(str)(raw) == "hello"
# %% | |
# get_init_func(str)(np.array(["hello"])) | |
StructT = TypeVar("StructT", bound="InitMixin") | |
class InitMixin: | |
""" | |
Init mixin to initialize dataclass attributes from | |
an object loaded with scipy.io.loadmat | |
""" | |
@classmethod | |
def init(cls: type[StructT], inp: object) -> StructT: | |
""" | |
Inspects the attribute annotations on the type `cls` and try to load them | |
attributes (with type) from the `inp` object. | |
If a non-optional attr is missing in `inp`, a KeyError is raised. | |
""" | |
# For initializing a matlab struct/cell read from scipy.io.loadmat, | |
# it's always one index in | |
if isinstance(inp, np.ndarray): | |
inp = inp[0] | |
data = {} | |
# alternatively, use dataclasses.fields | |
# https://docs.python.org/3/library/dataclasses.html#dataclasses.fields | |
anno = inspect.get_annotations(cls) | |
for attr_name, attr_type in anno.items(): | |
# when attr_name is a python reserved keyword | |
# e.g. "type" and "class", I use a "_" suffix | |
# so the name cane be used in a dataclass | |
_attr_name = attr_name.removesuffix("_") | |
# Check if attr_type is an optional type | |
_attr_type = attr_type | |
is_optional = False | |
if _tp := _get_optional_type(attr_type): | |
_attr_type = _tp | |
is_optional = True | |
# If a field is marked as `any`, make it optional | |
if attr_type is Any: | |
is_optional = True | |
try: | |
inp_data = inp[_attr_name] | |
except KeyError as e: | |
if is_optional: | |
data[attr_name] = None | |
continue | |
raise KeyError(f"Key {attr_name} not found when initializing f{cls}") | |
if ( | |
isinstance(inp_data, np.ndarray) | |
and inp_data.dtype == np.dtype("O") | |
and len(inp_data) == 1 | |
): | |
inp_data = inp_data[0] | |
try: | |
# If the annotated type has a `.init` method, use it to initialize | |
# the object. Otherwise, call `get_init_func` to generate a init func | |
# for the object | |
val = get_init_func(_attr_type)(inp_data) | |
except Exception as e: | |
print( | |
f'Error initializing "{cls}" on attribute "{attr_name}" of type "{_attr_type}"' | |
) | |
breakpoint() | |
raise e | |
else: | |
data[attr_name] = val | |
return cls(**data) | |
# %% | |
# Run the module's self-checks at import time, in definition order.
for _self_test in (
    _get_optional_type_test,
    _is_str_test,
    _type_is_str_test,
    _is_sequence_not_str_test,
    get_init_func_tests,
):
    _self_test()
del _self_test
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment