PySpark Data Handling Cheatsheet ⚙️
# 1. Data Loading
# 1.1 Load from CSV file
df = spark.read.csv('filename.csv', header=True, inferSchema=True)
# 1.2 Load from a JDBC source (e.g. a relational database)
df = spark.read.format("jdbc").options(url="jdbc_url", dbtable="tablename").load()
# 1.3 Load from text file
df = spark.read.text('filename.txt')
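# 1.4 Example (sketch): build the SparkSession that the snippets above assume;
# the app name is only an illustration
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('cheatsheet').getOrCreate()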
# 2. Data Inspection
# 2.1 Display top rows
df.show()
# 2.2 Print schema
df.printSchema()
# 2.3 Summary statistics
df.describe().show()
# 2.4 Count rows
df.count()
# 2.5 Display columns
df.columns
# 3. Data Cleaning
# 3.1 Drop rows with missing values
df.na.drop()
# 3.2 Fill missing values
df.na.fill(value)
# 3.3 Drop column
df.drop('columnname')
# 3.4 Rename column
df.withColumnRenamed('oldcolumnname', 'newcolumnname')
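# 3.5 Example (sketch): cleaning calls return new DataFrames rather than mutating
# df in place, so keep the result (column names are placeholders)
df_clean = df.na.fill(0).drop('columnname').withColumnRenamed('oldcolumnname', 'newcolumnname')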
# 4. Data Transformation
# 4.1 Select column(s)
df.select('column1name', 'column2name')
# 4.2 Add new or transform existing column
df.withColumn('newcolumnname', expression)
# 4.3 Filter rows greater or less than a threshold
df.filter(df['columnname'] > value)
df.filter(df['columnname'] < value)
# 4.4 Aggregation and group by column
df.groupBy('columnname').agg({'columnname': 'sum'})
# 4.5 Sort rows in descending order (default is ascending)
df.sort(df['columnname'].desc())
# 4.6 Count column values
df.groupBy('columnname').count().show()
# 4.7 Retrieve distinct column values
df.select('columnname').distinct().show()
# 4.8 Aggregations
df.groupBy().sum('columnname').show()
df.groupBy().max('columnname').show()
df.groupBy().min('columnname').show()
df.groupBy().avg('columnname').show()
# 4.9 Custom aggregations
from pyspark.sql import functions as F
df.groupBy('groupcolumnname').agg(F.sum('sumcolumnname'))
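# 4.10 Example (sketch): transformations are lazy and chain together; nothing runs
# until an action such as show() is called (column names are placeholders)
(df.filter(df['columnname'] > value)
   .groupBy('groupcolumnname')
   .agg(F.sum('sumcolumnname').alias('total'))
   .sort(F.col('total').desc())
   .show())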
# 5. SQL queries on DataFrames
# 5.1 Create temporary view
df.createOrReplaceTempView('viewname')
# 5.2 Query the view (replace the WHERE clause with your own condition)
spark.sql('SELECT * FROM viewname WHERE columnname > value')
# 5.3 Register DataFrame as a temporary table
df.createOrReplaceTempView('temptablename')
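# 5.4 Example (sketch): spark.sql() returns a DataFrame, so SQL and DataFrame
# operations can be mixed freely (view and column names are placeholders)
df.createOrReplaceTempView('viewname')
result = spark.sql('SELECT columnname, COUNT(*) AS n FROM viewname GROUP BY columnname')
result.show()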
# 6. Simple statistical analysis
# 6.1 Correlation matrix (the column must be a vector column)
from pyspark.ml.stat import Correlation
Correlation.corr(df, 'columnname')
# 6.2 Covariance
df.stat.cov('column1name', 'column2name')
# 6.3 Frequent items
df.stat.freqItems(['column1name', 'column2name'])
# 6.4 Stratified sampling by column
df.sampleBy('columnname', fractions={'class1': 0.1, 'class2': 0.2})
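# 6.5 Example (sketch): Correlation.corr expects a single vector column, so numeric
# columns are usually assembled first with VectorAssembler (column names are placeholders)
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
vec_df = VectorAssembler(inputCols=['column1name', 'column2name'], outputCol='features').transform(df)
Correlation.corr(vec_df, 'features').head()[0]  # Pearson correlation matrix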
# 7. Data handling for missing and duplicated values
# 7.1 Fill missing values in column
df.fillna({'columnname': value})
# 7.2 Remove duplicate rows
df.dropDuplicates()
# 7.3 Replace values in column
df.na.replace(['oldvalue'], ['newvalue'], 'columnname')
# 7.4 Remove rows with null values
df.na.drop()
# 7.5 Count total null values in each column
from pyspark.sql import functions as F
df.select([F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns])
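# 7.6 Example (sketch): dropDuplicates also accepts a subset of columns to
# deduplicate on (column names are placeholders)
df.dropDuplicates(['column1name', 'column2name'])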
# 8. Data conversion and export
# 8.1 Convert to Pandas DataFrame
pandas_df = df.toPandas()
# 8.2 Write DataFrame to CSV
df.write.csv('filepath.csv')
# 8.3 Write DataFrame to a JDBC source (e.g. a relational database)
df.write.format("jdbc").options(url="jdbc_url", dbtable="tablename").save()
# 9. Column operations
# 9.1 Change column type
df.withColumn('columnname', df['columnname'].cast('new_type'))
# 9.2 Split column into multiple
from pyspark.sql.functions import split
df.withColumn('newcolumnname', split(df['columnname'], 'delimiter')[0])
# 9.3 Concatenate columns
from pyspark.sql.functions import concat_ws
df.withColumn('newcolumnname', concat_ws(' ', df['column1name'], df['column2name']))
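# 9.4 Example (sketch): the column operations above chained together (column names
# and the delimiter are placeholders)
from pyspark.sql.functions import split, concat_ws
df2 = (df.withColumn('amount', df['columnname'].cast('double'))
         .withColumn('firstpart', split(df['textcolumnname'], ',')[0])
         .withColumn('combined', concat_ws(' ', df['column1name'], df['column2name'])))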
# 10. Date and time operations
from pyspark.sql.functions import current_date, date_format, date_add
# 10.1 Get current date
df.withColumn('currentdate', current_date())
# 10.2 Format date
df.withColumn('formatteddate', date_format('dateColumn', 'yyyyMMdd'))
# 10.3 Date arithmetic (e.g. date + 15 days)
df.withColumn('dateafterdays', date_add(df['date'], 15))
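# 10.4 Example (sketch): the date helpers combined, plus datediff to compute the
# gap between two dates (the 'date' column name is a placeholder)
from pyspark.sql.functions import current_date, date_add, datediff
df.withColumn('today', current_date()) \
  .withColumn('due', date_add(df['date'], 15)) \
  .withColumn('days_left', datediff('due', 'today')) \
  .show()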
# 11. Advanced data processing
# 11.1 Window functions
from pyspark.sql.window import Window
from pyspark.sql.functions import rank
df.withColumn('rank', rank().over(Window.partitionBy('column1name').orderBy('column2name')))
# 11.2 Pivot table
df.groupBy('columnname').pivot('pivotcolumnname').sum('sumcolumnname')
# 11.3 User defined functions (UDF)
from pyspark.sql.functions import udf
udf_function = udf(your_python_function)
df.withColumn('newcolumnname', udf_function(df['columnname']))
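# 11.4 Example (sketch): a UDF with an explicit return type, usually preferable so
# the result is not silently treated as StringType (column name is a placeholder)
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
length_udf = udf(lambda s: len(s) if s is not None else None, IntegerType())
df.withColumn('textlength', length_udf(df['columnname']))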
# 12. Performance optimization
# 12.1 Dataframe caching
df.cache()
# 12.2 Dataframe repartitioning
df.repartition(10)
# 12.3 Dataframe joining with broadcast
from pyspark.sql.functions import broadcast
df.join(broadcast(df2), 'key', 'inner')
# 12.4 Dataframe inner join
df1.join(df2, df1['id'] == df2['id'])
# 12.5 Dataframe left outer join
df1.join(df2, df1['id'] == df2['id'], 'left_outer')
# 12.6 Dataframe right outer join
df1.join(df2, df1['id'] == df2['id'], 'right_outer')
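# 12.7 Example (sketch): broadcasting suits a small lookup table joined to a large
# one; cache() is lazy and only materializes on the next action (the big_df/small_df
# names are illustrative)
from pyspark.sql.functions import broadcast
big_df.join(broadcast(small_df), 'key', 'inner').count()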
# 13. Complex data types
# 13.1 Exploding arrays
from pyspark.sql.functions import explode
df.withColumn('exploded', explode(df['arraycolumnname']))
# 13.2 Struct field access
df.select(df['structcolumnname']['field'])
# 13.3 Map handling
from pyspark.sql.functions import map_keys
df.select(map_keys(df['mapcolumnname']))
# 13.4 Read JSON
df = spark.read.json('filepath.json')
# 13.5 Expand all fields of a struct (e.g. parsed JSON) column
df.selectExpr('jsoncolumn.*')
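# 13.6 Example (sketch): read nested JSON, expand a struct column, and explode an
# array field in one select (file path, column and field names are placeholders)
from pyspark.sql.functions import explode
nested = spark.read.json('filepath.json')
nested.select('structcolumnname.*', explode('arraycolumnname').alias('item')).show()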
# 14. Load and save model
# 14.1 Load model
from pyspark.ml.classification import LogisticRegressionModel
LogisticRegressionModel.load('modelpath')
# 14.2 Save model
model.save('modelpath')
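# 14.3 Example (sketch): the fit/save/load round trip; the estimator, the training
# DataFrame (train_df) and the path are illustrative assumptions
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
model = LogisticRegression(featuresCol='features', labelCol='label').fit(train_df)
model.save('modelpath')
same_model = LogisticRegressionModel.load('modelpath')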