Last active
May 12, 2025 09:48
-
-
Save stuaxo/10b7927c79f0f8ce435770e5556bfb69 to your computer and use it in GitHub Desktop.
dataclass_to_pydantic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Extends https://yeonwoosung.github.io/posts/pydantic-vs-dataclass/
# so that nested dataclasses are converted as well.
from dataclasses import MISSING, _MISSING_TYPE, fields, is_dataclass
from functools import lru_cache
from typing import Any, Optional

import pydantic


@lru_cache(maxsize=None)
def convert_dataclass_to_pydantic(
    dcls: type, name: Optional[str] = None
) -> type[pydantic.BaseModel]:
    """Build a pydantic model class that mirrors the dataclass ``dcls``.

    Results are memoised by ``lru_cache``, so calling this again with the same
    arguments returns the same model class instead of creating a new one.

    Args:
        dcls: The dataclass to convert.
        name: Name for the generated model. When ``None``, the dataclass name
            with a ``"Model"`` suffix is used.

    Returns:
        The model class produced by ``pydantic.create_model``.
    """
    model_name = f"{dcls.__name__}Model" if name is None else name
    field_definitions = _get_pydantic_field_kwargs(dcls)
    return pydantic.create_model(model_name, **field_definitions)  # type: ignore
def _get_pydantic_field_kwargs(dcls: type) -> dict[str, tuple[type, Any]]: | |
# get attribute names and types from dataclass into pydantic format | |
pydantic_field_kwargs = dict() | |
for _field in fields(dcls): | |
# check is field has default value | |
field_type = _field.type | |
if isinstance(_field.default, _MISSING_TYPE): | |
# no default | |
default = ... | |
else: | |
default = _field.default | |
# TODO: Convert default if it's a dataclass | |
if(is_dataclass(field_type)): | |
# Convert fields that are dataclasses. | |
field_type = convert_dataclass_to_pydantic(field_type) | |
pydantic_field_kwargs[_field.name] = (field_type, default) | |
return pydantic_field_kwargs | |
# Example usage:
# define a pair of nested dataclasses, derive pydantic models from them,
# then push data through the generated models in both directions.
from dataclasses import dataclass
from typing import Literal


# Leaf dataclass with two string fields.
@dataclass
class Inner:
    a: str
    b: str


# Dataclass that nests an ``Inner`` instance.
@dataclass
class Outer:
    inner: Inner


# Derive the pydantic models from the dataclasses.
InnerModel = convert_dataclass_to_pydantic(Inner)
OuterModel = convert_dataclass_to_pydantic(Outer)

# Round trip, part 1: model instances -> raw (dict) data.
model_data = OuterModel(inner=InnerModel(a="hello", b="world"))
print(model_data.model_dump())

# Round trip, part 2: raw (dict) data -> OuterModel.
print(OuterModel(inner={"a": "hello", "b": "world"}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment