Skip to content

Instantly share code, notes, and snippets.

@deanm0000
Created July 23, 2024 15:51
Show Gist options
  • Save deanm0000/485291942b47b4d32113973eff493e72 to your computer and use it in GitHub Desktop.
Save deanm0000/485291942b47b4d32113973eff493e72 to your computer and use it in GitHub Desktop.
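Make a "fake" ufunc wrapper so custom Python functions (and pyarrow.compute functions) can be called directly on polars expressions without writing map_batches by hand.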
import functools

import polars as pl
import pyarrow.compute as pc
# Example frame: one Int64 column named "a" holding the values 1, 2, 3.
df = pl.Series("a", [1, 2, 3], dtype=pl.Int64).to_frame()
def make_fake_ufunc(func):
    """Wrap *func* so it can be called directly on polars expressions.

    If any argument of the returned wrapper is a ``pl.Expr``, the call is
    handed to that expression's ``__array_ufunc__`` (with the wrapper itself
    as the "ufunc"), so polars builds the expression instead of running
    *func* immediately.  Otherwise *func* runs right away and its result is
    coerced to a ``pl.Series``:

    * a ``pl.Series`` is returned unchanged;
    * a ``list``/``tuple`` becomes ``pl.Series(result)``;
    * for pyarrow.compute functions, the result goes through ``pl.from_arrow``;
    * anything else is wrapped as a length-1 ``pl.Series([result])``.

    A *func* carrying the ``__arrow_compute_function__`` attribute is treated
    as a pyarrow.compute function: every ``pl.Series`` argument, positional
    or keyword, is converted with ``.to_arrow()`` before the call.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapper. It has ufunc-like ``signature``, ``nout`` and ``types``
        attributes attached.
    """
    # The attribute does not change between calls, so check it once.
    is_pyarrow_compute = hasattr(func, "__arrow_compute_function__")

    # wraps() keeps func's name/docstring, so introspection and any messages
    # that mention the "ufunc" name show the real function, not "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Deferred path: positional args are checked first, then keyword
        # values.  The first Expr found takes over the whole call.
        # NOTE(review): polars' Expr.__array_ufunc__ appears to map only
        # positional inputs; an Expr passed solely by keyword may not be
        # resolved correctly — verify before relying on it.
        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, pl.Expr):
                return candidate.__array_ufunc__(wrapper, "__call__", *args, **kwargs)

        if is_pyarrow_compute:
            args = [a.to_arrow() if isinstance(a, pl.Series) else a for a in args]
            # Test each keyword value itself.  The original code tested a
            # leftover positional loop variable here, which crashed on
            # non-Series kwargs or when there were no positional args.
            kwargs = {
                key: value.to_arrow() if isinstance(value, pl.Series) else value
                for key, value in kwargs.items()
            }

        result = func(*args, **kwargs)
        if isinstance(result, pl.Series):
            return result
        if isinstance(result, (list, tuple)):
            return pl.Series(result)
        if is_pyarrow_compute:
            return pl.from_arrow(result)
        # Scalar results (e.g. an aggregation) become a one-element Series.
        return pl.Series([result])

    # Ufunc-like attributes so the wrapper can be passed to
    # Expr.__array_ufunc__.  A non-None `signature` presumably makes polars
    # treat this as a custom (not elementwise) ufunc — TODO confirm against
    # the polars version in use.
    wrapper.signature = "fake"
    wrapper.nout = 1
    wrapper.types = ["??->?"]
    return wrapper
# Example: an ordinary Python function made expression-aware by the decorator.
@make_fake_ufunc
def blah(x):
    """Return the sum of all elements of *x*, starting from 0.

    Returns 0 when *x* is empty.  Because the result is a scalar, the
    decorator wraps it into a one-element ``pl.Series``.
    """
    return sum(x)
# Demo: the decorated Python function applied to a column expression.
# The results are printed; as bare expressions they were silently discarded
# when the file ran as a script.
print(df.select(blah(pl.col("a"))))

# Demo: a pyarrow.compute kernel wrapped the same way.
# NOTE(review): polars may emit a custom-ufunc warning here because the
# wrapper's `signature` is not None — presumably expected; verify.
pccumsum = make_fake_ufunc(pc.cumulative_sum)
print(df.select(pccumsum(pl.col("a"))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment