# Dataclass to Pydantic BaseModel
#
# GitHub gist kalzoo/27479c08ad3c214c23876ee5cdcc0aa4 (can be saved to your
# computer and used in GitHub Desktop).
#
# GitHub notice: this file may contain hidden or bidirectional Unicode text
# that can be interpreted or compiled differently than it appears. To review
# it, open the file in an editor that reveals hidden Unicode characters (see
# GitHub's documentation on bidirectional Unicode characters).
import unittest
from dataclasses import MISSING, dataclass, fields, is_dataclass
from typing import Any, Type

from pydantic import BaseModel, ValidationError
# NOTE(review): MetaModel is the pre-1.0 name of pydantic's model metaclass
# (pydantic >= 1.0 calls it ModelMetaclass), so this import fails on newer
# pydantic releases — confirm the pinned pydantic version.
from pydantic.main import MetaModel
# Annotation alias for "some class object" (the argument of from_dataclass).
AnyType = Type[Any]


def from_dataclass(DataClass: AnyType) -> Type[BaseModel]:
    """Create a pydantic ``BaseModel`` subclass that mirrors ``DataClass``.

    The generated model copies the ``__name__``, ``__qualname__`` and
    ``__module__`` of ``DataClass``. A field annotation that is itself a
    dataclass is converted recursively, so nested payloads are validated as
    nested models. Plain ``default`` values of dataclass fields are carried
    over as model defaults, so those fields stay optional on the model.

    ``DataClass`` is never modified; the model gets its own annotations dict.

    Limitations: fields declared with ``default_factory`` become required on
    the model, and dataclasses nested inside containers (e.g. ``List[Inner]``
    or ``Optional[Inner]``) are passed through unconverted.

    Args:
        DataClass: The dataclass type to convert. For backward compatibility a
            non-dataclass class is also accepted; then only its
            ``__annotations__`` are used and no defaults are copied.

    Returns:
        A new ``BaseModel`` subclass (a class, not an instance).
    """
    namespace = {
        '__module__': DataClass.__module__,
        '__qualname__': DataClass.__qualname__,
    }

    if is_dataclass(DataClass):
        # fields() includes inherited fields and omits ClassVar / InitVar
        # pseudo-fields, which are not real model fields.
        declared = {}
        for dc_field in fields(DataClass):
            declared[dc_field.name] = dc_field.type
            if dc_field.default is not MISSING:
                # A class attribute in the namespace is how pydantic receives
                # a field default.
                namespace[dc_field.name] = dc_field.default
    else:
        declared = dict(DataClass.__annotations__)

    # Build a fresh annotations dict: writing converted models back into
    # DataClass.__annotations__ would leak them into the caller's dataclass.
    namespace['__annotations__'] = {
        name: from_dataclass(field_type) if is_dataclass(field_type) else field_type
        for name, field_type in declared.items()
    }

    # type(BaseModel) is pydantic's model metaclass (MetaModel before 1.0,
    # ModelMetaclass afterwards); calling it avoids a version-specific import.
    model_metaclass = type(BaseModel)
    return model_metaclass(DataClass.__name__, (BaseModel,), namespace)
class ModelTest(unittest.TestCase):
    """Round-trip and validation checks for models produced by from_dataclass."""

    def test_simple_success(self):
        # A well-typed flat payload must round-trip unchanged through .dict().
        @dataclass
        class SimpleDataClass:
            val: int

        payload = {'val': 123}
        instance = from_dataclass(SimpleDataClass)(**payload)
        self.assertEqual(instance.dict(), payload)

    def test_simple_validation_error(self):
        # A non-numeric string cannot be coerced into an int field.
        @dataclass
        class SimpleDataClass:
            val: int

        model_cls = from_dataclass(SimpleDataClass)
        self.assertRaises(ValidationError, model_cls, val='word')

    def test_nested_success(self):
        # A nested dataclass annotation must accept (and re-emit) a nested dict.
        @dataclass
        class DataClassInner:
            inner_val: str

        @dataclass
        class NestedDataClass:
            val: str
            inner_data_class: DataClassInner

        payload = {'val': 'word', 'inner_data_class': {'inner_val': 'word'}}
        instance = from_dataclass(NestedDataClass)(**payload)
        self.assertEqual(instance.dict(), payload)

    def test_nested_validation_error(self):
        # Validation errors inside the nested model must surface on the outer one.
        @dataclass
        class DataClassInner:
            inner_val: int

        @dataclass
        class NestedDataClass:
            val: str
            inner_data_class: DataClassInner

        model_cls = from_dataclass(NestedDataClass)
        bad_inner = {'inner_val': 'word'}
        self.assertRaises(
            ValidationError,
            model_cls,
            val='word',
            inner_data_class=bad_inner,
        )
if __name__ == '__main__':
    # Script entry point: run every test in this module, one report line per test.
    unittest.main(verbosity=2)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.