Skip to content

Instantly share code, notes, and snippets.

@deanm0000
Created March 27, 2025 17:22
Show Gist options
  • Save deanm0000/b488398f3192a909a43b92aa3a8af472 to your computer and use it in GitHub Desktop.
Save deanm0000/b488398f3192a909a43b92aa3a8af472 to your computer and use it in GitHub Desktop.
from typing import Iterable, cast
import polars as pl
from polars.testing import assert_frame_equal
import numpy as np
from datetime import timedelta
import datetime
import asyncio
import time
from io import StringIO
# Module-level row count. None of the functions visible in this file read it:
# make_df() takes its own ``n`` parameter and main() assigns a local ``n``.
n = 1_000
def make_df(n):
    """Build a random ``n``-row frame in which every column is a string.

    Columns, in order:
      * ``a`` - random integers drawn from ``[0, 10)``
      * ``b`` - samples from a normal distribution (mean 0, std 1)
      * ``d`` - random picks from ``"A"`` / ``"B"``
      * ``c`` - 2022-01-01 plus ``a`` hours

    All four columns are cast to ``pl.String`` before the frame is returned,
    so the result can serve as input for type-inference experiments.
    """
    raw = pl.DataFrame(
        {
            "a": np.random.randint(0, 10, n),
            "b": np.random.normal(0, 1, n),
            "d": np.random.choice(["A", "B"], n),
        }
    )
    # The timestamp is derived while "a" is still numeric.
    stamp = pl.datetime(2022, 1, 1) + pl.duration(hours=pl.col("a"))
    with_stamp = raw.with_columns(c=stamp)
    return with_stamp.with_columns(pl.all().cast(pl.String))
async def collect(lf):
    """Await ``lf.collect_async()`` and hand back whatever it produces.

    This exists only so the work can be scheduled with
    ``asyncio.create_task``: per the original author's note, create_task
    will not accept ``lf.collect_async`` directly because it is not a
    coroutine, so it is wrapped in one here.
    """
    frame = await lf.collect_async()
    return frame
async def waiter(
    tasks: "Iterable[asyncio.Task[pl.DataFrame]]", col: str
) -> "pl.Series":
    """Return the series from the highest-priority task that succeeds.

    The iteration order of ``tasks`` is the priority order: the first task
    is preferred over the second, and so on. A task's result is returned
    only once every task ahead of it has finished without a usable result,
    so the outcome does not depend on which task happens to finish first.
    Pass an ordered iterable (e.g. a list); a set gives an arbitrary order.

    Args:
        tasks: Tasks whose results expose ``to_series()``.
        col: Name of the column being converted. Not used by the function
            body; kept so existing callers continue to work.

    Returns:
        ``result().to_series()`` of the winning task.

    Raises:
        ValueError: If ``tasks`` is empty, or if every task failed or was
            cancelled.
    """
    ordered = list(tasks)
    if not ordered:
        raise ValueError("all failed")

    pending = set(ordered)
    while pending:
        _, pending = await asyncio.wait(
            pending, return_when=asyncio.FIRST_COMPLETED
        )

        winner = None
        for task in ordered:
            if not task.done():
                # A higher-priority candidate is still running; deciding
                # now could pick a less specific result.
                break
            if not task.cancelled() and task.exception() is None:
                winner = task
                break
        if winner is None:
            continue

        for other in ordered:
            if other is winner:
                continue
            if not other.done():
                other.cancel()
            elif not other.cancelled():
                # Retrieve the exception so asyncio does not log
                # "Task exception was never retrieved".
                other.exception()
        return winner.result().to_series()

    raise ValueError("all failed2")
async def auto_cast(df: pl.DataFrame) -> pl.DataFrame:
    """Try to convert every column of ``df`` to a more specific dtype.

    For each column four conversions are started as concurrent tasks, in
    this order: cast to Int64, cast to Float64, parse as datetime
    (``%Y-%m-%d %H:%M:%S%.f``), parse as date (``%Y-%m-%d``). ``waiter``
    chooses one result per column. Columns for which ``waiter`` does not
    return a series keep their original values from ``df``.

    Returns:
        A new frame with the columns in the same order as ``df.columns``.
    """

    def _candidates(name):
        # The order of this list is the order in which tasks are created.
        source = pl.col(name)
        return [
            source.cast(pl.Int64),
            source.cast(pl.Float64),
            source.str.to_datetime("%Y-%m-%d %H:%M:%S%.f"),
            source.str.to_date("%Y-%m-%d"),
        ]

    attempts = {
        name: [
            asyncio.create_task(collect(df.lazy().select(expr)))
            for expr in _candidates(name)
        ]
        for name in df.columns
    }

    # return_exceptions=True: a column whose conversions all fail produces
    # an exception object here instead of aborting the whole gather.
    outcomes = await asyncio.gather(
        *(waiter(pending, name) for name, pending in attempts.items()),
        return_exceptions=True,
    )

    converted = [item for item in outcomes if isinstance(item, pl.Series)]
    converted_names = {series.name for series in converted}
    untouched = set(df.columns) - converted_names

    # Re-select in the original column order, pulling unconverted columns
    # straight from the input frame.
    return pl.DataFrame(converted).select(
        df[name] if name in untouched else name for name in df.columns
    )
async def main():
    """Time ``auto_cast`` against a CSV round trip and compare the results.

    Builds a 1,000,000-row all-string frame, converts it with ``auto_cast``
    and prints the elapsed seconds, then writes the same frame to an
    in-memory CSV, reads it back with ``try_parse_dates=True`` and prints
    that elapsed time. Finally asserts that both approaches produced equal
    frames.

    Raises:
        AssertionError: If the two resulting frames differ.
    """
    n = 1000000
    df = make_df(n)

    # perf_counter is monotonic and meant for measuring intervals;
    # wall-clock time.time() can jump if the system clock is adjusted.
    started = time.perf_counter()
    new_df = await auto_cast(df)
    print(time.perf_counter() - started)

    started = time.perf_counter()
    buf = StringIO()
    df.write_csv(buf)
    buf.seek(0)  # rewind so read_csv starts at the beginning of the buffer
    new_df2 = pl.read_csv(buf, try_parse_dates=True)
    print(time.perf_counter() - started)

    assert_frame_equal(new_df, new_df2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment