@AlexanderNenninger
Last active April 14, 2024 13:51
Map datasets with metadata to parquet and back. Supports partitioning, cloud buckets etc. thanks to pyarrow.
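Roughly, a round trip looks like this (a minimal sketch, assuming `custom_dataset_to_parquet.py` below is on the import path; the column names, metadata values, and output path are illustrative):

```python
import polars as pl
from custom_dataset_to_parquet import DataSet, MetaData

# Build a dataset with some metadata attached (values made up for illustration).
ds = DataSet(
    metadata=MetaData(units={"float": "m/s"}),
    dataframe=pl.DataFrame({"partition": ["a", "a", "b"], "float": [1.0, 2.0, 3.0]}),
)

ds.write_parquet("example.parquet")  # single parquet file
roundtripped = DataSet.read_parquet("example.parquet")
roundtripped.assert_eq(ds)  # metadata and dataframe survive the round trip
```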
*.parquet binary merge=ours
*.parquet filter=lfs diff=lfs merge=lfs -text
*
!.gitignore
!custom_dataset_to_parquet.py
!test_custom_dataset_to_parqet.py
!.gitattributes
!jobs.parquet
!environment.yml
from __future__ import annotations
import json
from dataclasses import dataclass, field, is_dataclass
from math import inf, nan
from typing import Any, ClassVar, Dict, Optional
from warnings import warn
import polars as pl
import pyarrow.parquet as pq
from polars.testing import assert_frame_equal
class ConversionWarning(Warning):
    """Warns if serialization and deserialization will change data types."""


class MyJSONEncoder(json.JSONEncoder):
    """JSONEncoder that can handle data classes.

    Usage:
    ```python
    @dataclass(frozen=True)
    class Foo:
        bar: str

    foo = Foo(bar="baz")
    encoded = json.dumps(foo, cls=MyJSONEncoder)
    ```
    You can extend it to handle other custom types.
    """

    def default(self, o):
        if is_dataclass(o):
            return {f"__python__/dataclasses/{o.__class__.__name__}": vars(o)}
        return super().default(o)
class MyJSONDecoder(json.JSONDecoder):
    """JSONDecoder that can handle data classes.

    Usage:
    ```python
    @dataclass(frozen=True)
    class Foo:
        bar: str

    foo = Foo(bar="baz")
    encoded = json.dumps(foo, cls=MyJSONEncoder)
    foo_copy = json.loads(encoded, cls=MyJSONDecoder, classes=[Foo])
    assert foo == foo_copy
    ```
    You can extend it to handle other custom types.
    """

    def __init__(self, *args, classes=[], **kwargs):
        self.dataclass_name_mapping = {cls.__name__: cls for cls in classes}
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, dct):
        for cls_name in self.dataclass_name_mapping:
            identifier = f"__python__/dataclasses/{cls_name}"
            if identifier in dct:
                return self.dataclass_name_mapping[cls_name](**dct[identifier])
        return dct
@dataclass
class MetaData:
    """Some dataclass for testing."""

    foo: str = "Foo"
    bar: float = inf
    baz: Optional[int] = None
    units: Dict[str, str] = field(default_factory=dict)

    def assert_eq(self, other):
        assert self == other, f"{self=}, {other=}"
@dataclass
class DataSet:
    """Custom dataset that saves metadata within a parquet file as utf-8 encoded JSON."""

    metadata: Optional[MetaData]
    dataframe: pl.DataFrame
    _class_id: ClassVar[str] = "__python__/my_namespace/DataSet"

    def __post_init__(self):
        self.class_id = f"__python__/datasets/{self.__class__.__name__}"

    def assert_eq(self, other: DataSet, **kwargs):
        """Assert that self == other. Required since `dataframe` is a `polars.DataFrame`.

        Args:
            other (DataSet): Object to compare to.
        """
        assert isinstance(other, self.__class__)
        self.metadata.assert_eq(other.metadata)
        assert_frame_equal(self.dataframe, other.dataframe, **kwargs)

    def write_parquet(self, location: Any, **kwargs):
        """Serialize DataSet to parquet.

        Args:
            location (Any): Any file location that can be handled by `pyarrow.parquet.write_table`.
            **kwargs: Will be passed on to `pyarrow.parquet.write_table` or
                `pyarrow.parquet.write_to_dataset`.
        """
        # Dump metadata to a utf-8 encoded JSON string using the custom encoder.
        metadata_bytes = json.dumps(self.metadata, cls=MyJSONEncoder).encode("utf-8")
        # Convert `dataframe` to a pyarrow table.
        table = self.dataframe.to_arrow()
        # Add our own metadata to the table schema.
        existing_metadata = table.schema.metadata or {}
        new_metadata = {
            self._class_id.encode("utf-8"): metadata_bytes,
            **existing_metadata,
        }
        table = table.replace_schema_metadata(new_metadata)
        # Categorical columns with lexical ordering will be mapped to Categorical(ordering="physical").
        for col in self.dataframe:
            if col.dtype.is_(pl.Categorical("lexical")):
                warn(
                    f"Column {col.name} with dtype {col.dtype} will be converted to {pl.Categorical('physical')}",
                    ConversionWarning,
                )
        # Write table to parquet.
        if partition_cols := kwargs.get("partition_cols"):
            # All partition columns will be cast to `Categorical(ordering="physical")`.
            for partition_col in partition_cols:
                if not self.dataframe[partition_col].dtype.is_(
                    pl.Categorical("physical")
                ):
                    warn(
                        f"Column {partition_col} of dtype {self.dataframe[partition_col].dtype} "
                        f"will be converted to {pl.Categorical('physical')}.",
                        ConversionWarning,
                    )
            pq.write_to_dataset(table=table, root_path=location, **kwargs)
        else:
            pq.write_table(table=table, where=location, **kwargs)

    @classmethod
    def read_parquet(cls, location: Any, **kwargs) -> DataSet:
        """Deserialize `DataSet` from parquet.

        Args:
            location (Any): Where to read the data from. Anything supported by pyarrow works.

        Returns:
            DataSet: Deserialized `DataSet`.
        """
        table = pq.read_table(location, **kwargs)
        # Get metadata from the table.
        table_metadata = table.schema.metadata
        try:  # Try to parse `MetaData` from `table_metadata`.
            dataset_metadata_str = table_metadata.pop(
                cls._class_id.encode("utf-8")
            ).decode("utf-8")
            dataset_metadata = json.loads(
                dataset_metadata_str, cls=MyJSONDecoder, classes=[MetaData]
            )
        except (KeyError, json.JSONDecodeError):
            dataset_metadata = None
        # Restore the remaining schema metadata without our custom key.
        table = table.replace_schema_metadata(table_metadata)
        dataframe = pl.from_arrow(table)
        return DataSet(metadata=dataset_metadata, dataframe=dataframe)
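
# For reference: once a file has been written, the embedded metadata can also be
# inspected with plain pyarrow (a sketch; "example.parquet" is an assumed path):
#
#   meta = pq.read_schema("example.parquet").metadata
#   json.loads(meta[b"__python__/my_namespace/DataSet"])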
if __name__ == "__main__":
    import shutil
    from io import BytesIO
    from pathlib import Path

    try:
        # MetaData test
        metadata = MetaData()
        serialized = json.dumps(metadata, cls=MyJSONEncoder)
        deserialized = json.loads(serialized, cls=MyJSONDecoder, classes=[MetaData])
        metadata.assert_eq(deserialized)
        print(serialized)

        # DataSet test
        # Partitioning casts the partition column to categorical.
        dataframe = pl.DataFrame(
            {
                "partition": ["1", "1", "2", "2"],
                "integer": [4, 5, 6, None],
                "float": [4.0, None, nan, inf],
                "string": ["d", "e", "f", None],
            },
        ).with_columns(pl.col("partition").cast(pl.Categorical(ordering="physical")))
        dataset = DataSet(metadata=metadata, dataframe=dataframe)

        with BytesIO() as f:
            dataset.write_parquet(f)
            deserialized_dataset = DataSet.read_parquet(f)
            dataset.assert_eq(deserialized_dataset)
            print(deserialized_dataset.metadata)
            print(deserialized_dataset.dataframe)

        # With partitions
        Path("./data").mkdir(exist_ok=True)
        dataset.write_parquet(
            "data/dataset",
            partition_cols=["partition"],
            existing_data_behavior="delete_matching",
        )
        partitioned_dataset = DataSet.read_parquet("data/dataset")
        dataset.assert_eq(partitioned_dataset, check_column_order=False)
        print(partitioned_dataset)
    finally:
        shutil.rmtree("./data", ignore_errors=True)
from math import inf, nan
import polars as pl
import polars.testing
import polars.testing.parametric
import pytest
from hypothesis import given
from custom_dataset_to_parquet import ConversionWarning, DataSet, MetaData
@pytest.fixture()
def dataframe():
    return pl.DataFrame(
        {
            "partition": ["1", "1", "2", "2"],
            "lexical_partition": ["1", "1", "2", "2"],
            "string_partition": ["1", "2", "1", "2"],
            "int_partition": [1, 2, 1, 2],
            "integer": [1, 2, 3, None],
            "float": [4.0, None, nan, inf],
            "string": ["d", "e", "f", None],
            "categorical": ["a", "b", "a", None],
            "list": [[1, 2], [3, None, 5], None, [7, 8, 9]],
            "datetime": ["2024-01-01", "2024-02-29T16:32", None, None],
        },
        schema_overrides={
            "partition": pl.Categorical(ordering="physical"),
            "lexical_partition": pl.Categorical(ordering="lexical"),
            "categorical": pl.Categorical(ordering="physical"),
            "list": pl.List(pl.Int32),
        },
    ).with_columns(
        pl.col("datetime").str.strptime(
            dtype=pl.Datetime(time_unit="ms", time_zone="Europe/Berlin"),
        )
    )
@pytest.fixture
def metadata():
    return MetaData()


@pytest.fixture
def dataset(metadata, dataframe):
    return DataSet(metadata, dataframe)
def test_write_read_without_partitioning(dataset, tmp_path):
    with pytest.warns(ConversionWarning):
        dataset.write_parquet(tmp_path / "no-partition.parquet")
    deserialized = DataSet.read_parquet(tmp_path / "no-partition.parquet")
    dataset.assert_eq(deserialized, check_dtype=False)
    # Changed dtype.
    assert deserialized.dataframe["lexical_partition"].dtype.is_(
        pl.Categorical("physical")
    ), f"{deserialized.dataframe['lexical_partition'].dtype}"
    pl.testing.assert_frame_equal(
        dataset.dataframe.select(pl.exclude("lexical_partition")),
        deserialized.dataframe.select(pl.exclude("lexical_partition")),
    )
def test_write_read_with_partitioning(dataset, tmp_path):
    # Partitioning on the already-physical "partition" column; the lexical column still warns.
    with pytest.warns(ConversionWarning):
        dataset.write_parquet(
            tmp_path / "partition_no_warning", partition_cols=["partition"]
        )
    deserialized = DataSet.read_parquet(tmp_path / "partition_no_warning")
    dataset.assert_eq(deserialized, check_column_order=False, check_dtype=False)
    # Changed dtype.
    assert deserialized.dataframe["lexical_partition"].dtype.is_(
        pl.Categorical("physical")
    ), f"{deserialized.dataframe['lexical_partition'].dtype}"
def test_lexical_partition(dataset, tmp_path):
    with pytest.warns(ConversionWarning):
        partition_col = "lexical_partition"
        dataset.write_parquet(tmp_path / partition_col, partition_cols=[partition_col])
        deserialized = DataSet.read_parquet(tmp_path / partition_col)
    assert dataset.dataframe[partition_col].dtype.is_(pl.Categorical("lexical"))
    assert deserialized.dataframe[partition_col].dtype.is_(pl.Categorical("physical"))
    polars.testing.assert_frame_equal(
        dataset.dataframe[[partition_col, "integer"]],
        deserialized.dataframe[[partition_col, "integer"]],
        check_row_order=False,
        check_column_order=False,
        check_dtype=False,
    )
def test_string_partition(dataset, tmp_path):
    with pytest.warns(ConversionWarning):
        partition_col = "string_partition"
        dataset.write_parquet(tmp_path / partition_col, partition_cols=[partition_col])
        deserialized = DataSet.read_parquet(tmp_path / partition_col)
    polars.testing.assert_frame_equal(
        dataset.dataframe[[partition_col, "integer"]],
        deserialized.dataframe[[partition_col, "integer"]],
        check_row_order=False,
        check_column_order=False,
        check_dtype=False,
    )
def test_int_partition(dataset, tmp_path):
    with pytest.warns(ConversionWarning):
        partition_col = "int_partition"
        dataset.write_parquet(tmp_path / partition_col, partition_cols=[partition_col])
        deserialized = DataSet.read_parquet(tmp_path / partition_col)
    polars.testing.assert_frame_equal(
        dataset.dataframe[[partition_col, "integer"]],
        deserialized.dataframe.select(
            pl.col(partition_col).cast(pl.Int64), "integer"
        ),
        check_row_order=False,
        check_column_order=False,
        check_dtype=False,
    )
# Sketch of a property-based test. Note: `include_cols` is a `dataframes()` parameter
# and expects column specs; the extra categorical column name below is illustrative.
@given(
    polars.testing.parametric.dataframes(
        cols=100,
        size=100,
        allowed_dtypes=pl.FLOAT_DTYPES | pl.DATETIME_DTYPES,
        include_cols=[polars.testing.parametric.column("partition", dtype=pl.Categorical)],
    )
)
def large_dataset(df: pl.DataFrame):
    print(df)
@AlexanderNenninger (Author) commented:

Damn, got a typo in the file name.