PySpark Data Handling Cheatsheet ⚙️
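# Note: every snippet below assumes an active SparkSession bound to `spark`.
# A minimal setup sketch (the app name is arbitrary):
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('cheatsheet').getOrCreate()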
# 1. Data Loading
# 1.1 Load from CSV file
df = spark.read.csv('filename.csv', header=True, inferSchema=True)
# 1.2 Load from a database via JDBC (Java Database Connectivity)
# (real connections usually also need user, password, and driver options)
df = spark.read.format("jdbc").options(url="jdbc_url", dbtable="tablename").load()
# 1.3 Load from text file
df = spark.read.text('filename.txt')
# 2. Data Inspection
# 2.1 Display top rows
df.show()
# 2.2 Print schema
df.printSchema()
# 2.3 Summary statistics
df.describe().show()
# 2.4 Count rows
df.count()
# 2.5 Display columns
df.columns
# 3. Data Cleaning
# 3.1 Drop missing values
df.na.drop()
# 3.2 Fill missing values
df.na.fill(value)
# 3.3 Drop column
df.drop('columnname')
# 3.4 Rename column
df.withColumnRenamed('oldcolumnname', 'newcolumnname')
# 4. Data Transformation
# 4.1 Select column(s)
df.select('column1name', 'column2name')
# 4.2 Add new or transform old column
df.withColumn('newcolumnname', expression)
# 4.3 Filter rows greater than or less than a value
df.filter(df['columnname'] > value)
df.filter(df['columnname'] < value)
# 4.4 Aggregation and group by column
df.groupBy('columnname').agg({'columnname': 'sum'})
# 4.5 Sort rows in descending order (default is ascending)
df.sort(df['columnname'].desc())
# 4.6 Count column values
df.groupBy('columnname').count().show()
# 4.7 Retrieve distinct column values
df.select('columnname').distinct().show()
# 4.8 Aggregations
df.groupBy().sum('columnname').show()
df.groupBy().max('columnname').show()
df.groupBy().min('columnname').show()
df.groupBy().avg('columnname').show()
# 4.9 Custom aggregations
from pyspark.sql import functions as F
df.groupBy('groupcolumnname').agg(F.sum('sumcolumnname'))
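# A fuller sketch of agg(): several aggregates at once, renamed with alias().
# Column names are placeholders, assuming df was loaded as in section 1.
df.groupBy('groupcolumnname').agg(
    F.sum('sumcolumnname').alias('total'),
    F.avg('sumcolumnname').alias('average'),
    F.count('*').alias('rows')
)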
# 5. SQL queries on DataFrames
# 5.1 Create temporary view
df.createOrReplaceTempView('viewname')
# 5.2 Retrieve all records in the view that match a condition
spark.sql('SELECT * FROM viewname WHERE columnname = value')
# 5.3 Register DataFrame as table
df.createOrReplaceTempView('temptablename')
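# End-to-end sketch of the temp-view workflow on toy data (names are illustrative).
people = spark.createDataFrame([('Alice', 34), ('Bob', 45)], ['name', 'age'])
people.createOrReplaceTempView('people')
spark.sql('SELECT name FROM people WHERE age > 40').show()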
# 6. Simple statistical analysis
# 6.1 Correlation matrix
from pyspark.ml.stat import Correlation
Correlation.corr(df, 'columnname')  # 'columnname' must be a vector column; see the sketch below
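# Correlation.corr expects a single vector column, so numeric columns are usually
# assembled first. Minimal sketch on toy data (column names are illustrative).
from pyspark.ml.feature import VectorAssembler
toy = spark.createDataFrame([(1.0, 2.0), (2.0, 4.1), (3.0, 6.2)], ['x', 'y'])
assembled = VectorAssembler(inputCols=['x', 'y'], outputCol='features').transform(toy)
Correlation.corr(assembled, 'features').show()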
# 6.2 Covariance
df.stat.cov('column1name', 'column2name')
# 6.3 Frequent items across columns
df.stat.freqItems(['column1name', 'column2name'])
# 6.4 Stratified sampling by column
df.sampleBy('columnname', fractions={'class1': 0.1, 'class2': 0.2})
# 7. Data handling for missing and duplicated values
# 7.1 Fill missing values in column
df.fillna({'columnname': value})
# 7.2 Remove duplicates
df.dropDuplicates()
# 7.3 Replace missing values
df.na.replace(['oldvalue'], ['newvalue'], 'columnname')
# 7.4 Remove rows with null values
df.na.drop()
# 7.5 Count total null values in each column
from pyspark.sql import functions as F
df.select([F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns])
# 8. Data conversion and export
# 8.1 Convert to Pandas DataFrame
pandas_df = df.toPandas()
# 8.2 Write DataFrame to CSV
df.write.csv('filepath.csv')
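# Common write options (a sketch; the output path is illustrative):
# include the header row and overwrite any existing output directory.
df.write.mode('overwrite').option('header', True).csv('output_dir')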
# 8.3 Write DataFrame to a database via JDBC (Java Database Connectivity)
df.write.format("jdbc").options(url="jdbc_url", dbtable="tablename").save()
# 9. Column operations
from pyspark.sql.functions import split, concat_ws
# 9.1 Change column type
df.withColumn('columnname', df['columnname'].cast('new_type'))
# 9.2 Split column into multiple
df.withColumn('newcolumnname', split(df['columnname'], 'delimiter')[0])
# 9.3 Concatenate columns
df.withColumn('newcolumnname', concat_ws(' ', df['column1name'], df['column2name']))
# 10. Date and time operations
from pyspark.sql.functions import current_date, date_format, date_add
# 10.1 Get current date
df.withColumn('currentdate', current_date())
# 10.2 Format date
df.withColumn('formatteddate', date_format('dateColumn', 'yyyyMMdd'))
# 10.3 Date arithmetic (e.g. date + 15 days)
df.withColumn('dateafterdays', date_add(df['date'], 15))
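# Toy sketch combining the date helpers above (column names are illustrative).
from pyspark.sql.functions import to_date
dates = spark.createDataFrame([('2024-01-01',)], ['datestring'])
dates = dates.withColumn('date', to_date('datestring'))
dates.withColumn('plus15days', date_add('date', 15)) \
     .withColumn('yearmonth', date_format('date', 'yyyyMM')) \
     .show()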
# 11. Advanced data processing
# 11.1 Window functions
from pyspark.sql.window import Window
from pyspark.sql.functions import rank
df.withColumn('rank', rank().over(Window.partitionBy('column1name').orderBy('column2name')))
# 11.2 Pivot table
df.groupBy('columnname').pivot('pivotcolumnname').sum('sumcolumnname')
# 11.3 User defined functions (UDF)
from pyspark.sql.functions import udf
udf_function = udf(your_python_function)
df.withColumn('newcolumnname', udf_function(df['columnname']))
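# Concrete UDF sketch: the return type is declared explicitly (the default is
# StringType). The function and column names are illustrative.
from pyspark.sql.types import IntegerType

@udf(returnType=IntegerType())
def string_length(s):
    return len(s) if s is not None else None

df.withColumn('lengthcolumn', string_length(df['columnname']))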
# 12. Performance optimization
# 12.1 DataFrame caching
df.cache()
# 12.2 DataFrame repartitioning
df.repartition(10)
# 12.3 DataFrame joining with broadcast
from pyspark.sql.functions import broadcast
df1.join(broadcast(df2), 'key', 'inner')
# 12.4 DataFrame inner join
df1.join(df2, df1['id'] == df2['id'])
# 12.5 DataFrame left outer join
df1.join(df2, df1['id'] == df2['id'], 'left_outer')
# 12.6 DataFrame right outer join
df1.join(df2, df1['id'] == df2['id'], 'right_outer')
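# Toy sketch of the joins above: two small frames sharing an 'id' key.
df1 = spark.createDataFrame([(1, 'a'), (2, 'b')], ['id', 'left_val'])
df2 = spark.createDataFrame([(2, 'x'), (3, 'y')], ['id', 'right_val'])
df1.join(df2, 'id', 'inner').show()       # only matching ids
df1.join(df2, 'id', 'left_outer').show()  # all rows from df1, nulls where unmatched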
# 13. Complex data types
from pyspark.sql.functions import explode, map_keys
# 13.1 Exploding arrays
df.withColumn('exploded', explode(df['arraycolumnname']))
# 13.2 Struct field access
df.select(df['structcolumnname']['field'])
# 13.3 Map key access
df.select(map_keys(df['mapcolumnname']))
# 13.4 Read JSON file
df = spark.read.json('filepath.json')
# 13.5 Expand all fields of a struct column (e.g. parsed JSON)
df.selectExpr('jsoncolumn.*')
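# Toy sketch for arrays, structs, and maps (data and names are illustrative).
from pyspark.sql import Row
nested = spark.createDataFrame([
    Row(tags=['a', 'b'], info=Row(city='Oslo', zip='0150'), attrs={'k1': 'v1'})
])
nested.withColumn('tag', explode('tags')).show()
nested.select(nested['info']['city']).show()
nested.select(map_keys(nested['attrs'])).show()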
# 14. Load and save model
# 14.1 Load model
from pyspark.ml.classification import LogisticRegressionModel
LogisticRegressionModel.load('modelpath')
# 14.2 Save model
model.save('modelpath')
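# End-to-end sketch: fit a LogisticRegression on toy data, save it, reload it.
# The path 'modelpath' is a placeholder, as above.
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.linalg import Vectors
train = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([1.0]), 1.0)],
    ['features', 'label'])
model = LogisticRegression().fit(train)
model.write().overwrite().save('modelpath')
reloaded = LogisticRegressionModel.load('modelpath')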