Last active
August 6, 2024 10:02
-
-
Save AlexanderNenninger/23eaf710929ca0e12d32c99681459919 to your computer and use it in GitHub Desktop.
Tensorize Polars Dataframe
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Callable, Iterable, Protocol, Tuple | |
import polars as pl | |
class SupportsReshape(Protocol): | |
def reshape(shape: int | Tuple[int], *args, **kwargs) -> Any: | |
... | |
def tensorize(
    frame: pl.DataFrame | pl.LazyFrame,
    index_columns: Iterable[str],
    columns_last: bool = True,
    tensorization_callback: Callable[[pl.DataFrame], SupportsReshape] | None = None,
) -> Any:
    """Pack a DataFrame with an N-level index into an order N+1 tensor.

    The distinct values of every index column span one tensor axis (in order of
    first appearance in ``frame``); the remaining "value" columns span one more
    axis. Index combinations that do not occur in ``frame`` are padded with nulls.

    **Warning**: If index values are duplicated, the first row will be picked.

    Args:
        frame (pl.DataFrame | pl.LazyFrame): Polars DataFrame or LazyFrame to tensorize.
        index_columns (Iterable[str]): Columns forming an index. Must not be empty.
        columns_last (bool, optional): Whether to put the value columns of frame as
            last axis (shape ``[n_0, ..., n_N, n_columns]``) or as first axis
            (shape ``[n_columns, n_0, ..., n_N]``). Defaults to True.
        tensorization_callback (Callable[[pl.DataFrame], SupportsReshape] | None, optional):
            Function converting a Polars DataFrame to a 2-D array. If None is passed,
            a Numpy array is created. With ``columns_last=False`` the callback
            receives the transposed frame (one row per value column), so the value
            columns should share a common dtype in that case.

    Returns:
        Any: Array containing the padded and tensorized dataframe.

    Raises:
        ValueError: If ``index_columns`` is empty.
    """
    # Materialize once: the argument is only promised to be an Iterable, but it is
    # iterated several times and measured with len() below.
    index_columns = list(index_columns)
    if not index_columns:
        raise ValueError("index_columns must contain at least one column name")

    # Set default value of tensor converter.
    to_tensor = (
        tensorization_callback
        if tensorization_callback is not None
        else (lambda df: df.to_numpy())
    )

    # Handle both LazyFrames and DataFrames.
    lazy = frame.lazy()

    # One single-column frame per index column holding its distinct values.
    # maintain_order=True makes the axis layout reproducible between runs.
    index_levels = [
        lazy.select(pl.col(name).unique(maintain_order=True)) for name in index_columns
    ]
    level_sizes = [level.select(pl.len()).collect().item() for level in index_levels]
    n_value_columns = lazy.collect_schema().len() - len(index_columns)

    # Generate product space over all index levels. Each cross join repeats the
    # rows accumulated so far once per value of the next level, so the first index
    # column varies slowest and the last one fastest.
    full_index = index_levels[0]
    for level in index_levels[1:]:
        full_index = full_index.join(level, how="cross")

    # Pad empty index spots: index combinations without a match get null values.
    # NOTE(review): this relies on the left join keeping the row order of
    # full_index - confirm for the Polars version / engine in use.
    padded = (
        full_index.join(
            lazy.group_by(index_columns).agg(pl.all().first()),
            on=index_columns,
            how="left",
        )
        .select(pl.exclude(index_columns))
        .collect()
    )

    # Pack into tensor. The padded frame is (prod(level_sizes), n_value_columns)
    # with rows in row-major index order, so the index axes must appear in their
    # original order and the value columns form the trailing axis.
    if columns_last:
        return to_tensor(padded).reshape(level_sizes + [n_value_columns])

    # Columns first: a plain reshape cannot move the column axis to the front, so
    # transpose the frame before converting it.
    return to_tensor(padded.transpose()).reshape([n_value_columns] + level_sizes)
# Small demonstration: tensorize a frame with a two-level index and print the result.
if __name__ == "__main__":
    test_data = pl.DataFrame(
        {
            "index_0": ["A", "A", "B", "C", "B", "B"],
            "index_1": ["a", "b", "a", "b", "a", "b"],
            "values_0": ["Aa0", "Ab0", "Ba0", "Cb0", "Ba0", "Bb0"],
            "values_1": ["Aa1", "Ab1", "Ba1", "Cb1", "Ba1", "Bb1"],
        }
    )
    print(test_data)

    index_columns = ["index_0", "index_1"]
    # Pass the converter by keyword: the third positional parameter of tensorize
    # is columns_last, so a positional lambda would be taken as that flag and the
    # callback would never be used.
    res = tensorize(
        test_data,
        index_columns,
        tensorization_callback=lambda df: df.to_numpy(),
    )
    print(res)
    # Flatten back to one row per index combination for easy visual comparison.
    print(res.reshape(-1, test_data.width - len(index_columns)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment