|
# SPDX-License-Identifier: MIT |
|
|
|
import dataclasses |
|
import enum |
|
import types |
|
|
|
|
|
# The next two internal classes exist to properly handle accessing .value. |
|
# |
|
# Ordinarily, in an enum with a mixed-in data type (e.g. IntEnum), .value
# on a member returns an object of the member type (e.g. an int). Here,
# since we are combining the dataclass (the member type) and the enum
# itself into one, we need to do a little special handling.
|
# |
|
# If we have a dataclass field named "value", then we let the enum treat |
|
# that field as the underlying data type. The purpose of this is to |
|
# allow lookups by value: |
|
# |
|
# @dataenum |
|
# class Errno: |
|
# value: int |
|
# strerror: str |
|
# |
|
# EPERM = 1, "Operation not permitted" |
|
# ENOENT = 2, "No such file or directory" |
|
# ... |
|
# |
|
# try: |
|
# ... |
|
# except OSError as e: |
|
# print(Errno(e.errno).strerror) |
|
# |
|
# If we do not have a dataclass field named "value", then we want .value |
|
# to point back to self (i.e. SomeEnum.blah.value is SomeEnum.blah). |
|
# |
|
# The Enum implementation does some trickery with the specific attribute |
|
# "value": it sets this to a property/descriptor that behaves |
|
# differently as a class attribute and an instance attribute, so you can |
|
# have an enum member named "value" and access it from the enum class. |
|
# From an instance, the descriptor's getter returns self._value_ and |
|
# there is no setter. This gets in our way if we wish to have a |
|
# dataclass field named "value", as the __init__ function generated by |
|
# @dataclass tries to assign to it. So, if we do have a field named |
|
# "value" (which implies that we cannot have a member named "value"), |
|
# then we inherit from _ValueFieldDataEnum, which simply serves to |
|
# shadow the "value" descriptor so assignment behaves normally in a |
|
# subclass. |
|
# |
|
# We also need two more tweaks in _ValueFieldDataEnum. First, because |
|
# the Enum implementation uses ._value_ instead of .value internally |
|
# (e.g. in computing the member map), we need to set both. Second, the |
|
# __repr__ now needs to operate on self, not on self._value_, so we |
|
# include a modified version of Enum.__repr__ that hard-codes the |
|
# dataclass repr behavior but runs it on the right object. |
|
# |
|
# If we don't have a field named "value", we inherit from |
|
# _SelfValuedDataEnum, which instead sets ._value_ to self. |
|
# |
|
# In either case, we have to set self._value_ in __new__, not in |
|
# __init__, as mentioned in the enum documentation. (The reason for this |
|
# is complicated. Calling the enum constructor, as in Errno(e.errno) in |
|
# the example above, should return one of the existing members of the |
|
# enum, and so Enum overrides __new__ with special behavior. But at some |
|
# point you actually do need to construct instances to create those |
|
# members. So, it takes advantage of the fact that you can explicitly |
|
# call a base class's __new__ without using constructor syntax, and |
|
# bypass the current __new__. However, an implication of this is that |
|
# __init__ is not called either. So Enum manually calls it, but it calls |
|
# it after it looks up (or sets) ._value_ on the object, and so setting |
|
# self._value_ in __init__ is too late. I suspect this is a solvable |
|
# bug.) |
|
|
|
class _ValueFieldDataEnum(enum.Enum): |
|
""" |
|
A version of the Enum base for dataclasses with a field named 'value'. |
|
|
|
See the comments above for an explanation. |
|
""" |
|
value = enum.nonmember(None) |
|
|
|
def __new__(cls, *args, **kwargs): |
|
self = super(enum.Enum, cls).__new__(cls) |
|
self.__init__(*args, **kwargs) |
|
self._value_ = self.value |
|
# This is tricky - as mentioned above, Enum manually calls |
|
# __init__ but too late. We just called it ourselves, and we |
|
# don't want to call it twice, so replace it with an empty |
|
# implementation (that also does not call super().__init__) to |
|
# effectively supppress Enum's manual call. |
|
self.__init__ = lambda self, *args, **kwargs: None |
|
return self |
|
|
|
def __repr__(self): |
|
return "<%s.%s: %s>" % (self.__class__.__name__, self._name_, enum._dataclass_repr(self)) |
|
|
|
|
|
class _SelfValuedDataEnum(enum.Enum): |
|
""" |
|
A version of the Enum base for dataclasses without a field named 'value'. |
|
|
|
See the comments in above for an explanation. |
|
""" |
|
def __new__(cls, *args, **kwargs): |
|
self = super(enum.Enum, cls).__new__(cls) |
|
self._value_ = self |
|
return self |
|
|
|
|
|
def dataenum(cls):
    """
    Turn a class into both an enum and a dataclass.

    Every field in a dataclass must have an annotation (it may also
    have an actual value, either a default or dataclasses.field()).
    Members of an enum [should not have a type annotation][1]. Therefore, we
    can express an enum whose member type is a dataclass in a single
    class definition: anything typed is a field and anything untyped is
    a member.

    This decorator turns the argument class into a frozen dataclass using
    the :func:`dataclasses.dataclass` decorator, creates a new
    :class:`enum.Enum` subclass that inherits from the dataclass, and
    returns the new subclass. The new class keeps the decorated class's
    ``__module__``, ``__qualname__`` and (if set) ``__doc__``.

    Names starting with an underscore stay on the dataclass and never
    become members. Annotated names that are not real fields (such as
    ``typing.ClassVar`` or ``dataclasses.InitVar``) are not members either.

    Typical usage:

        @dataenum
        class MyDataEnum:
            field1: int
            field2: str

            member1 = 1, "one"
            member2 = 2, "two"

    [1] https://typing.readthedocs.io/en/latest/spec/enums.html

    Raises:
        TypeError: if the class has a dataclass field named "name", or if
            :func:`dataclasses.dataclass` itself rejects the class.
    """
    datacls = dataclasses.dataclass(frozen=True)(cls)
    fields = {field.name for field in dataclasses.fields(datacls)}
    # __dataclass_fields__ records every annotated name, including the
    # ClassVar/InitVar pseudo-fields that dataclasses.fields() omits. All of
    # them are "typed" and therefore must never become enum members.
    typed = set(datacls.__dataclass_fields__) | fields

    if "name" in fields:
        raise TypeError("Cannot have a dataclass field named 'name' in a dataenum")

    if "value" in fields:
        parent = _ValueFieldDataEnum
    else:
        parent = _SelfValuedDataEnum

    def populate_ns(ns):
        # types.new_class would otherwise take __module__ from whatever frame
        # builds the class; keep the decorated class's identity instead.
        ns["__module__"] = cls.__module__
        ns["__qualname__"] = cls.__qualname__
        if cls.__doc__ is not None:
            ns["__doc__"] = cls.__doc__
        # Anything untyped and not private becomes an instance of the
        # dataclass and a member of the enum.
        for key, val in cls.__dict__.items():
            if key not in typed and not key.startswith("_"):
                ns[key] = val

    return types.new_class(cls.__name__, (datacls, parent), {}, populate_ns)
|
|
|
### |
|
|
|
import dataclasses |
|
from typing import Callable |
|
import unittest |
|
|
|
|
|
# Based on https://blog.glyph.im/2025/01/active-enum.html |
|
@dataenum
class SomeNumber:
    # Typed names are dataclass fields: a number and a zero-argument effect.
    result: int
    effect: Callable[[], None]

    # Untyped names are members; each tuple supplies the fields in order.
    one = (1, lambda: print("one!"))
    two = (2, lambda: print("two!"))
    three = (3, lambda: print("three!"))
|
|
|
|
|
@dataenum
class SomeNumberV:
    # Same shape as SomeNumber, but the first field is named "value", which
    # enables lookup by that field (SomeNumberV(1) is SomeNumberV.one).
    value: int
    effect: Callable[[], None]

    one = (1, lambda: print("one!"))
    two = (2, lambda: print("two!"))
    three = (3, lambda: print("three!"))
|
|
|
|
|
@dataenum
class Defaulted:
    # Fixture: an untyped member sandwiched between two typed, defaulted
    # assignments. The typed ones must stay dataclass fields; only member1
    # becomes an enum member (its value 2 fills field1 positionally).
    field1: int = 1
    member1 = 2
    field2: int = 3
|
|
|
|
|
@dataenum
class Featureful:
    # Fixture exercising dataclasses.field() options on a dataenum.
    field1: str = dataclasses.field(repr=False)  # hidden from repr
    field2: int = dataclasses.field(default=1)
    field3: list = dataclasses.field(default_factory=list)
    # The single value "Hello" fills field1; field2/field3 use their defaults.
    member1 = "Hello"
|
|
|
|
|
@dataenum
class MemberNamedValue:
    # Fixture: with no field named "value", "value" may be used as a member
    # name. Each member's single value fills field "a".
    a: int
    hue = 1
    saturation = 2
    value = 3
|
|
|
|
|
class DataEnumTests(unittest.TestCase):
    """Tests for the @dataenum decorator and the fixture enums above."""

    def test_somenumber(self):
        """Test a @dataenum without a value field."""
        first = SomeNumber.one

        # Without a "value" field, raw data is not a lookup key and .value
        # refers back to the member itself.
        with self.assertRaises(ValueError):
            SomeNumber(1)
        self.assertIs(first.value, first)

        self.assertEqual(first.result, 1)
        self.assertTrue(callable(first.effect))
        self.assertNotEqual(first, SomeNumber.two)

        for obj in (SomeNumber, first):
            self.assertTrue(dataclasses.is_dataclass(obj))

        rendered = repr(first)
        self.assertTrue(rendered.startswith("<SomeNumber.one: result=1, effect="))

        results = [member.result for member in SomeNumber]
        self.assertEqual(results, [1, 2, 3])

    def test_basic_value(self):
        """Test a @dataenum with a value field."""
        first = SomeNumberV.one
        # A "value" field makes lookup by value return the member.
        self.assertIs(SomeNumberV(1), first)

        self.assertEqual(first.value, 1)
        self.assertTrue(callable(first.effect))
        self.assertNotEqual(first, SomeNumberV.two)

        for obj in (SomeNumberV, first):
            self.assertTrue(dataclasses.is_dataclass(obj))

        rendered = repr(first)
        self.assertTrue(rendered.startswith("<SomeNumberV.one: value=1, effect="))

        values = [member.value for member in SomeNumberV]
        self.assertEqual(values, [1, 2, 3])

    def test_assignment(self):
        """Typed assignments should count as fields, not enum members."""
        member = Defaulted.member1
        self.assertEqual(len(Defaulted), 1)
        self.assertIn(member, Defaulted)
        self.assertNotIn(Defaulted.field1, Defaulted)
        self.assertEqual((member.field1, member.field2), (2, 3))
        # This is consistent with @dataclass
        self.assertEqual(Defaulted.field1, 1)

    def test_dataclass_features(self):
        """dataclasses.field should work."""
        member = Featureful.member1
        self.assertNotIn("Hello", repr(member))
        self.assertEqual((member.field2, member.field3), (1, []))

    def test_invalid_fields(self):
        """Test various invalid field names."""
        # An untyped dataclasses.field() is rejected.
        with self.assertRaises(TypeError):
            @dataenum
            class Untyped:
                a = dataclasses.field(default=3)
                b = 1

        # A field called "name" is rejected.
        with self.assertRaises(TypeError):
            @dataenum
            class NamedName:
                name: str
                value: int
                a = "hi", 1

        # This one doesn't raise, but _reserved_ is considered a random
        # class variable on the dataclass, not a member.
        @dataenum
        class WithReserved:
            value: int
            _reserved_ = 1

        self.assertEqual(WithReserved._reserved_, 1)
        self.assertNotIsInstance(WithReserved._reserved_, WithReserved)

    def test_member_named_value(self):
        """Ensure 'value' can be a member name instead of a field name."""
        expectations = (
            (MemberNamedValue.hue, 1),
            (MemberNamedValue.saturation, 2),
            (MemberNamedValue.value, 3),
        )
        for member, expected_a in expectations:
            self.assertEqual(member.a, expected_a)

    def test_init_called_once(self):
        """__post_init__ runs once per member (no "value" field)."""
        calls = 0

        @dataenum
        class Count:
            a: int

            def __post_init__(self):
                nonlocal calls
                calls += 1

            x = 1
            y = 2

        self.assertEqual(calls, 2)

    def test_init_called_once_value(self):
        """__post_init__ runs once per member (with a "value" field)."""
        calls = 0

        @dataenum
        class Count:
            value: int

            def __post_init__(self):
                nonlocal calls
                calls += 1

            x = 1
            y = 2

        self.assertEqual(calls, 2)