Created
March 27, 2025 17:22
-
-
Save deanm0000/b488398f3192a909a43b92aa3a8af472 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Iterable, cast | |
import polars as pl | |
from polars.testing import assert_frame_equal | |
import numpy as np | |
from datetime import timedelta | |
import datetime | |
import asyncio | |
import time | |
from io import StringIO | |
# Module-level default row count; main() uses its own local n.
n = 1000
def make_df(n):
    """Return an ``n``-row DataFrame of random sample data, all columns String.

    Columns, in order:
      * ``a`` - integers from ``np.random.randint(0, 10, n)``
      * ``b`` - floats from ``np.random.normal(0, 1, n)``
      * ``d`` - random picks from ``["A", "B"]``
      * ``c`` - ``pl.datetime(2022, 1, 1)`` plus ``a`` hours
    Once ``c`` has been derived, every column is cast to ``pl.String``.
    """
    # Draw in the same order as the dict literal would: a, b, then d.
    ints = np.random.randint(0, 10, n)
    floats = np.random.normal(0, 1, n)
    labels = np.random.choice(["A", "B"], n)
    base = pl.DataFrame({"a": ints, "b": floats, "d": labels})

    hours_offset = pl.duration(hours=pl.col("a"))
    stamped = base.with_columns(c=pl.datetime(2022, 1, 1) + hours_offset)
    return stamped.with_columns(pl.all().cast(pl.String))
async def collect(lf):
    """Await ``lf.collect_async()`` from inside a genuine coroutine.

    ``asyncio.create_task`` only accepts coroutines, and the awaitable
    returned by ``collect_async`` is not one, so this thin adapter wraps it.
    """
    pending = lf.collect_async()
    frame = await pending
    return frame
async def waiter(
    tasks: "Iterable[asyncio.Task[pl.DataFrame]]", col: str
) -> "pl.Series":
    """Return the result of the highest-priority cast task that succeeds.

    ``tasks`` holds one already-scheduled task per candidate dtype, ordered
    from most to least preferred. All of them keep running concurrently.
    This coroutine awaits them in that order and returns the single-column
    result of the first one that completes without raising.

    The earlier implementation returned whichever successful task happened
    to finish first. For integer-looking strings that made the Int64-vs-
    Float64 choice a race. Awaiting in priority order makes the choice
    deterministic. Any task still running when this returns (or raises) is
    cancelled.

    Args:
        tasks: scheduled cast tasks, most preferred first. Pass an ordered
            iterable (e.g. a list); a set would lose the priority order.
        col: name of the column being cast, used in the error message.

    Returns:
        ``frame.to_series()`` of the winning task's result.

    Raises:
        ValueError: if every task failed or ``tasks`` was empty.
    """
    ordered = list(tasks)
    try:
        for task in ordered:
            try:
                frame = await task
            except Exception:
                # A failed cast only means the column is not of this dtype;
                # move on to the next candidate.
                continue
            return frame.to_series()
        raise ValueError(f"all failed: no cast succeeded for column {col!r}")
    finally:
        for task in ordered:
            if not task.done():
                task.cancel()
            elif not task.cancelled():
                # Retrieve any stored exception so asyncio does not log
                # "Task exception was never retrieved" for finished failures.
                task.exception()
async def auto_cast(df: pl.DataFrame) -> pl.DataFrame:
    """Try to convert each column of *df* to a more specific dtype.

    For every column, four lazy ``collect`` jobs are scheduled concurrently:
    a ``cast`` to ``pl.Int64``, a ``cast`` to ``pl.Float64``, a
    ``str.to_datetime("%Y-%m-%d %H:%M:%S%.f")`` parse and a
    ``str.to_date("%Y-%m-%d")`` parse. Each column's jobs go to ``waiter``,
    which picks the result. Columns whose ``waiter`` raises are kept
    unchanged. The returned frame keeps the original column order.
    """

    def build_expr(name, dtype):
        # Datetime/Date need explicit string parsing; the rest use a cast.
        column = pl.col(name)
        if dtype == pl.Datetime:
            return column.str.to_datetime("%Y-%m-%d %H:%M:%S%.f")
        if dtype == pl.Date:
            return column.str.to_date("%Y-%m-%d")
        return column.cast(dtype)

    candidates = [pl.Int64, pl.Float64, pl.Datetime, pl.Date]
    per_column = {
        name: [
            asyncio.create_task(collect(df.lazy().select(build_expr(name, dtype))))
            for dtype in candidates
        ]
        for name in df.columns
    }

    # return_exceptions=True turns an all-failed column into a returned
    # exception object instead of aborting the whole gather.
    outcomes = await asyncio.gather(
        *(waiter(jobs, name) for name, jobs in per_column.items()),
        return_exceptions=True,
    )
    converted = [item for item in outcomes if isinstance(item, pl.Series)]

    # Replacing columns by name keeps the original order; uncast columns
    # remain exactly as they were in df.
    return df.with_columns(converted)
async def main(n: int = 1_000_000):
    """Benchmark ``auto_cast`` against a CSV round-trip with date parsing.

    Builds an ``n``-row String frame with ``make_df`` and times
    ``auto_cast`` on it. It then times writing the same frame to an
    in-memory CSV and reading it back with
    ``pl.read_csv(..., try_parse_dates=True)``. Both elapsed times are
    printed in seconds. Finally it asserts that the two results are equal.

    Args:
        n: number of rows to generate; defaults to the original 1,000,000.
    """
    df = make_df(n)

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which follows the wall clock and can jump during a measurement.
    start = time.perf_counter()
    new_df = await auto_cast(df)
    print(time.perf_counter() - start)

    start = time.perf_counter()
    buf = StringIO()
    df.write_csv(buf)
    buf.seek(0)
    new_df2 = pl.read_csv(buf, try_parse_dates=True)
    print(time.perf_counter() - start)

    assert_frame_equal(new_df, new_df2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment