Get feature permutations, one for each row, using PySpark
Gist mkaranasou/3c0720f7868c753b74214fd2272f8447, created March 19, 2021 16:58
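```python
from pyspark.sql import DataFrame
from pyspark.sql import functions as F


def get_features_permutations(
        df: DataFrame,
        feature_names: list,
        output_col: str = 'features_permutations'
) -> DataFrame:
    """
    Adds a column `output_col` that holds the feature names as an array and
    shuffles that array independently for every row, so each row gets its
    own random permutation of `feature_names`. Example contents:
        [feat2, feat4, feat3, feat1],
        [feat3, feat4, feat2, feat1],
        [feat1, feat2, feat4, feat3],
        ...
    Note: F.shuffle (available since Spark 2.4) is non-deterministic, so the
    permutations differ between runs.
    """
    # Literal array column [feat1, feat2, ...], identical on every row
    feature_array = F.array(*[F.lit(f) for f in feature_names])
    # F.shuffle is evaluated per row, so each row gets its own random order
    return df.withColumn(output_col, F.shuffle(feature_array))
```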
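To make the per-row semantics concrete without a Spark cluster, here is a minimal plain-Python sketch of the same idea. The helper `features_permutations_per_row`, its `seed` parameter, and the modelling of rows as dicts are all hypothetical, not part of the gist or of PySpark. The seed exists only to make local runs reproducible; in Spark, `F.shuffle` is non-deterministic.

```python
import random
from typing import List, Optional


def features_permutations_per_row(
    rows: List[dict],
    feature_names: List[str],
    output_col: str = "features_permutations",
    seed: Optional[int] = None,
) -> List[dict]:
    """Hypothetical plain-Python analogue of get_features_permutations.

    Each row gets its own independently shuffled copy of feature_names,
    mirroring what F.shuffle(F.array(...)) does row by row in Spark.
    """
    rng = random.Random(seed)  # seeded only so local runs are reproducible
    result = []
    for row in rows:
        # sample() with k == len(...) returns a new shuffled list and
        # leaves feature_names itself untouched
        perm = rng.sample(feature_names, k=len(feature_names))
        result.append({**row, output_col: perm})
    return result


rows = [{"id": i} for i in range(3)]
features = ["feat1", "feat2", "feat3", "feat4"]
for r in features_permutations_per_row(rows, features, seed=42):
    print(r["id"], r["features_permutations"])
```

Every output row contains each feature exactly once, just in a different order, which matches the docstring example above.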