Creating a Spark UDF with extra parameters via currying
// Problem: creating a Spark UDF that takes extra parameters at invocation time.
// Solution: use currying.
// http://stackoverflow.com/questions/35546576/how-can-i-pass-extra-parameters-to-udfs-in-sparksql
// We want to create hideTabooValues, a Spark UDF that sets to -1 every field that contains any of the given taboo values.
// E.g. forbiddenValues = [1, 2, 3]
//      dataframe = [1, 2, 3, 4, 5, 6]
//      dataframe.select(hideTabooValues(forbiddenValues)) :> [-1, -1, -1, 4, 5, 6]
//
// Implementing this in Spark, we run into two major issues:
// 1) Spark UDF factories do not support parameter types other than Column
// 2) While we can define the UDF behaviour, we cannot fix the taboo list content before the actual invocation.
//
// To overcome these limitations, we exploit Scala's functional programming capabilities and use currying.
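// (A quick plain-Scala aside, independent of Spark, to make the currying pattern concrete.
// makeAdder and addTen are illustrative names added here, not part of the original gist:
// fixing the first argument returns a specialised function that can be applied later.)
def makeAdder(x: Int): Int => Int = (y: Int) => x + y
val addTen = makeAdder(10)
addTen(5) // res: Int = 15
// hideTabooValues below follows the same pattern: the taboo list is fixed first,
// and the resulting UDF is applied to a Column only at select time.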
import org.apache.spark.sql._
import org.apache.spark.sql.types._

// Just create a simple dataframe with integers from 0 to 999.
val rowRDD = sc.parallelize(0 to 999).map(Row(_))
val schema = StructType(StructField("value", IntegerType, true) :: Nil)
val rowDF = sqlContext.createDataFrame(rowRDD, schema)
// Here we use currying: hideTabooValues is a function of type List[Int] => UserDefinedFunction.
def hideTabooValues(taboo: List[Int]) = udf((n: Int) => if (taboo.contains(n)) -1 else n)
// Simplifying, you can see hideTabooValues as a UDF factory that specialises the given UDF definition at invocation time.
// Without a taboo list applied, hideTabooValues is just an ordinary function:
hideTabooValues _
// res7: List[Int] => org.apache.spark.sql.UserDefinedFunction = <function1>
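// Since hideTabooValues is just a factory, we can derive several independent UDFs from it,
// each specialised with its own taboo list (hideSmall and hideLarge are illustrative names,
// not part of the original gist). Both values are reusable UserDefinedFunction instances.
val hideSmall = hideTabooValues(List(0, 1, 2))
val hideLarge = hideTabooValues(List(997, 998, 999))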
// It's time to try our UDF! Let's define the taboo list first,
val forbiddenValues = List(0, 1, 2)
// and then use Spark SQL to apply the UDF. Note the two invocations here: the first creates the specialised UDF
// with the given taboo list, and the second applies that UDF in a classic select instruction.
rowDF.select(hideTabooValues(forbiddenValues)(rowDF("value"))).show(6)
// +----------+
// |UDF(value)|
// +----------+
// |        -1|
// |        -1|
// |        -1|
// |         3|
// |         4|
// |         5|
// +----------+
// only showing top 6 rows
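// The same currying trick works when you prefer SQL strings: register a specialised function
// under a name and query a temp table. This is only a sketch, assuming the same Spark 1.x API
// used above (sqlContext.udf.register / registerTempTable); the names hideTabooValuesPlain,
// "hide_taboo" and "numbers" are arbitrary choices made for this example.
def hideTabooValuesPlain(taboo: List[Int]) = (n: Int) => if (taboo.contains(n)) -1 else n
sqlContext.udf.register("hide_taboo", hideTabooValuesPlain(forbiddenValues))
rowDF.registerTempTable("numbers")
sqlContext.sql("SELECT hide_taboo(value) AS value FROM numbers").show(6)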