This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == '__main__':
    # NOTE(review): the original imported `functions as F` twice (once in the
    # combined import and once on its own line); the duplicate is removed.
    from pyspark.sql import SparkSession, functions as F
    from pyspark import SparkConf

    # An empty SparkConf: cluster/driver settings can be added here or are
    # picked up from spark-defaults / the environment.
    conf = SparkConf()

    # Build (or reuse) the session for this demo application.
    spark = SparkSession.builder \
        .config(conf=conf) \
        .appName('Dataframe with Indexes') \
        .getOrCreate()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import operator | |
import os | |
import time | |
import warnings | |
from pyspark.ml.linalg import Vectors, VectorUDT | |
from pyspark.sql import functions as F, SparkSession, types as T, Window |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
+----+-------------------------------------+--------------+---------------------+
|id  |features                             |prediction    |marginal_contribution|
+----+-------------------------------------+--------------+---------------------+
|1677|[0.349,0.141,0.162,0.162,0.162,0.349]|0.0 |null | | |
|1677|[0.886,0.141,0.162,0.162,0.162,0.349]|0.0 |0.0 | | |
|2250|[0.106,0.423,0.777,0.777,0.777,0.886]|0.0 |null | | |
|2250|[0.886,0.423,0.777,0.777,0.777,0.886]|0.0 |0.0 | | |
|2453|[0.801,0.423,0.777,0.777,0.87,0.886] |0.0 |null | | |
+----+-------------------------------------+--------------+---------------------+ | |
only showing top 5 rows |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Row: Row(id=964, features=DenseVector([0.886, 0.423, 0.777, 0.777, 0.777, 0.886])) | |
Calculating SHAP values for "f0"... | |
+----+-----+-----+-----+-----+-----+-----+-------------------------------------+-----+-----------+------------------------+------------------------------------------------------------------------------+ | |
|id |f0 |f1 |f2 |f3 |f4 |f5 |features |label|is_selected|features_permutations |x | | |
+----+-----+-----+-----+-----+-----+-----+-------------------------------------+-----+-----------+------------------------+------------------------------------------------------------------------------+ | |
|1677|0.349|0.141|0.162|0.162|0.162|0.349|[0.349,0.141,0.162,0.162,0.162,0.349]|1 |false |[f5, f2, f1, f4, f3, f0]|[[0.349,0.141,0.162,0.162,0.162,0.349], [0.886,0.141,0.162,0.162,0.162,0.349]]| | |
|2250|0.106|0.938|0.434|0.434|0.434|0.106|[0.106,0.938,0.434,0.434,0.434,0.106]|0 |false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import numpy as np | |
import pyspark | |
from shapley_spark_calculation import \ | |
calculate_shapley_values, select_row | |
from pyspark.ml.classification import RandomForestClassifier, LinearSVC, \ | |
DecisionTreeClassifier | |
from pyspark.ml.evaluation import BinaryClassificationEvaluator | |
from pyspark.ml.feature import VectorAssembler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Ship the feature vector of the row being explained, plus the canonical
# feature-name ordering, to every executor as read-only broadcast variables.
ROW_OF_INTEREST_BROADCAST = spark.sparkContext.broadcast(
    row_of_interest[features_col])
ORDERED_FEATURE_NAMES = spark.sparkContext.broadcast(feature_names)
# set up the udf - x-j and x+j need to be calculated for every row | |
def calculate_x( | |
feature_j, z_features, curr_feature_perm | |
): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyspark | |
from pyspark.sql import functions as F | |
def get_features_permutations( | |
df: pyspark.DataFrame, | |
feature_names: list, | |
output_col='features_permutations' | |
): | |
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from psutil import virtual_memory | |
from pyspark import SparkConf | |
from pyspark.ml.linalg import Vectors, VectorUDT | |
from pyspark.sql import functions as F, SparkSession, types as T, Window | |
def get_spark_session(): | |
""" | |
With an effort to optimize memory and partitions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: look up the Spark SQL type string of the 'features' column.
# NOTE(review): fixed the original `spark.DataFrame(...)` — SparkSession has
# no `DataFrame` constructor; DataFrames are created via `createDataFrame`.
df = spark.createDataFrame(...)
# df.dtypes is a list of (column_name, dtype_string) pairs; turning it into a
# dict lets us fetch one column's dtype by name (None if the column is absent).
dict(df.dtypes).get('features')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Re-create the created_at range check on the weekly partition table,
-- constraining rows to ISO week 15 of 2020 (Mon 2020-04-06 .. Sun 2020-04-12).
-- NOTE(review): the upper bound relies on microsecond precision
-- ('…23:59:59.999999'); confirm the column is timestamp(6), otherwise a
-- half-open check (created_at < '2020-04-13') would be safer.
ALTER TABLE IF EXISTS table_y2020_w15
    DROP CONSTRAINT table_y2020_w15_created_at_check,
    ADD CONSTRAINT table_y2020_w15_created_at_check
        CHECK (created_at >= '2020-04-06 00:00:00'
           AND created_at <= '2020-04-12 23:59:59.999999');
NewerOlder