delete.py
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as spark_sum, when

# Initialize Spark session (required: `spark` is used below)
spark = SparkSession.builder \
    .appName("Sales Analysis") \
    .getOrCreate()
# Sample data
data = [
    (100, 'a', 'b', 'c', 's1', 'p1'),
    (200, 'a', 'b', 'c', 's2', 'p1'),
    (300, 'a', 'b', 'c', 's3', 'p2'),
    (100, 'd', 'e', 'f', 's4', 'p2'),
    (100, 'd', 'e', 'f', 's5', 'p3'),
    (100, 'd', 'e', 'f', 's6', 'p4'),
]

# Define schema
schema = ["sales", "region", "country", "city", "stage", "pincode"]

# Create DataFrame
df = spark.createDataFrame(data, schema=schema)
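# For reference, df.show() at this point should print the six sample rows
# roughly as follows (formatting may differ slightly by Spark version):
# +-----+------+-------+----+-----+-------+
# |sales|region|country|city|stage|pincode|
# +-----+------+-------+----+-----+-------+
# |  100|     a|      b|   c|   s1|     p1|
# |  200|     a|      b|   c|   s2|     p1|
# |  300|     a|      b|   c|   s3|     p2|
# |  100|     d|      e|   f|   s4|     p2|
# |  100|     d|      e|   f|   s5|     p3|
# |  100|     d|      e|   f|   s6|     p4|
# +-----+------+-------+----+-----+-------+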
def create_new_column(df, filter_cols, filter_vals, agg_col, new_col):
    """
    Filter, aggregate, and create/update a new column based on conditions
    in a PySpark DataFrame.

    Args:
    - df (DataFrame): Input PySpark DataFrame.
    - filter_cols (list of str): List of column names for filtering.
    - filter_vals (list or tuple): Values corresponding to filter_cols.
    - agg_col (str): Column to aggregate.
    - new_col (str): Name of the new column to be created or updated.

    Returns:
    - DataFrame: DataFrame with the new/updated column added based on the conditions.
    """
    # Construct the combined filter condition: AND of one equality test per
    # (column, value) pair. An equivalent functools.reduce formulation is
    # sketched after this function.
    filter_condition = None
    for col_name, col_val in zip(filter_cols, filter_vals):
        if filter_condition is None:
            filter_condition = col(col_name) == col_val
        else:
            filter_condition = filter_condition & (col(col_name) == col_val)

    # Filter the data
    filtered_data = df.filter(filter_condition)
    filtered_data.show()  # debug: display the filtered rows

    # Group by the filter columns and sum the aggregation column
    grouped_data = filtered_data.groupBy(*filter_cols) \
        .agg(spark_sum(agg_col).alias('total_sales'))

    # Join the group total back onto the filtered rows, then set the new
    # column based on whether the total exceeds the threshold
    result = filtered_data.join(grouped_data, filter_cols, 'left_outer') \
        .withColumn(new_col, when(col('total_sales') > 500, 'OK').otherwise('Not OK')) \
        .drop('total_sales')

    return result
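# The accumulator loop above can also be written more compactly with
# functools.reduce; a minimal equivalent sketch (build_filter is an
# illustrative helper, not part of the original gist):
from functools import reduce
from operator import and_

def build_filter(filter_cols, filter_vals):
    # AND together one equality test per (column, value) pair
    return reduce(and_, [col(c) == v for c, v in zip(filter_cols, filter_vals)])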
# Example usage: filter on region='a', country='b', city='c', pincode='p1',
# aggregate sales, and create/update a new column based on the condition
fil = ['a', 'b', 'c', 'p1']
# fil1 = ['d', 'e', 'f', '']
filtered_result = create_new_column(df, ['region', 'country', 'city', 'pincode'], fil, 'sales', 'new_column')

# Show the final result
filtered_result.show()
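# With the sample data above, the (a, b, c, p1) group sums to 100 + 200 = 300,
# which is not greater than 500, so both matching rows should come out 'Not OK'.
# Expected final output (exact column order may vary, since the join places
# the key columns first):
# +------+-------+----+-------+-----+-----+----------+
# |region|country|city|pincode|sales|stage|new_column|
# +------+-------+----+-------+-----+-----+----------+
# |     a|      b|   c|     p1|  100|   s1|    Not OK|
# |     a|      b|   c|     p1|  200|   s2|    Not OK|
# +------+-------+----+-------+-----+-----+----------+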