import java.sql.Timestamp
import java.util.UUID

// from the awesome http://db-blog.web.cern.ch/
import ch.cern.sparkmeasure.StageMetrics

import org.apache.spark.SparkConf
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.joda.time.DateTime

import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}
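
// Goal: check which of the predicates below each Spark version actually pushes down to
// the Parquet reader. Every Scenario runs against the same generated Parquet files and
// sparkmeasure's StageMetrics reports how many records and bytes were read.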
case class Scenario(name: String, filter: DataFrame => DataFrame)

object Main extends App {

  val id = UUID.randomUUID

  val Jan = new Timestamp(DateTime.parse("2017-01-01T00:00:00Z").getMillis)
  val Feb = new Timestamp(DateTime.parse("2017-02-01T00:00:00Z").getMillis)
  val Mar = new Timestamp(DateTime.parse("2017-03-01T00:00:00Z").getMillis)
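
  // Scenarios cover the column types and shapes whose pushdown behaviour may differ:
  // INT/FLOAT/BOOL/STRING/TIMESTAMP comparisons, IS NOT NULL checks, a missing column,
  // a nested struct field, a column name containing a dot, and plain projections.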
  val scenarios = Seq(
    Scenario("No filter", identity),
    Scenario("Filter on INT", _.filter(col("int_col") >= 2)),
    Scenario("Filter on non null INT", _.filter(col("int_null_col").isNotNull)),
    Scenario("Filter on BOOL", _.filter(col("bool_col") =!= false)),
    Scenario("Filter on == FLOAT", _.filter(col("float_col") === 1.0)),
    Scenario("Filter on <= FLOAT", _.filter(col("float_col") <= 2.0)),
    Scenario("Filter on STRING", _.filter(col("string_col") === "event 1")),
    Scenario("Filter on non null STRING", _.filter(col("string_null_col").isNotNull)),
    Scenario("Filter on >= TS", _.filter(col("timestamp_col") >= Feb)),
    Scenario("Filter on == TS", _.filter(col("timestamp_col") === Feb)),
    Scenario("Filter on missing column", _.filter(col("doesntexist") === 2)),
    Scenario("Filter on nested col INT", _.filter(col("nested.int_col") === 2)),
    Scenario("Filter on dotted col INT", _.filter(col("`dotted.int_col`") === 2)),
    Scenario("Column projection", _.select(col("int_col"), col("timestamp_col"))),
    Scenario("Select constant", _.select(lit("foo").as("bar")))
  )
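
  // Pushdown-related settings: spark.sql.parquet.filterPushdown enables predicate
  // pushdown on the Spark side, while the parquet.filter.* properties control
  // statistics- and dictionary-based row-group filtering inside the Parquet reader.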
  val conf = new SparkConf()
    .set("spark.master", "local[1]")
    .set("spark.app.name", "PredicatePushdownTest")
    .set("spark.sql.parquet.filterPushdown", "true")
    .set("spark.sql.parquet.mergeSchema", "false")
    // .set("spark.sql.parquet.int96AsTimestamp", "false")
    // .set("spark.sql.parquet.int64AsTimestampMillis", "true") // <- when set to true, statistics are filled but still not used
    .set("parquet.filter.statistics.enabled", "true")
    .set("parquet.filter.dictionary.enabled", "true")

  val sparkSession = SparkSession.builder().config(conf).getOrCreate()
  sparkSession.sparkContext.setLogLevel("WARN")

  println(s"TESTING SPARK VERSION ${sparkSession.version}")
  val schema = StructType(
    Seq(
      StructField("id", IntegerType),
      StructField("string_col", StringType),
      StructField("string_null_col", StringType),
      StructField("timestamp_col", TimestampType),
      StructField("int_col", IntegerType),
      StructField("int_null_col", IntegerType),
      StructField("float_col", FloatType),
      StructField("bool_col", BooleanType),
      StructField("nested", StructType(Seq(StructField("int_col", IntegerType)))),
      StructField("dotted.int_col", IntegerType)
    )
  )
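
  // Three rows, one per month; string_null_col and int_null_col are null in the last
  // row so that the IS NOT NULL filters have something to eliminate.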
  val rows = Seq(
    Row(1, "event 1", "event 1", Jan, 1, 1, 1.0f, true, Row(1), 1),
    Row(2, "event 2", "event 2", Feb, 2, 2, 2.0f, true, Row(2), 2),
    Row(3, "event 1", null, Mar, 3, null, 3.0f, false, Row(3), 3)
  )

  val initDf = sparkSession.createDataFrame(rows.asJava, schema)
  initDf.show()

  // partition by a non-tested column to force one record per file (to avoid row grouping)
  initDf.write.partitionBy("id").parquet(s"/tmp/restitution/test-$id")
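
  // Run every scenario against the freshly written files and collect one stats DataFrame
  // per run; runAndMeasure limits the captured metrics to that scenario's job.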
  var testResult = Seq.empty[DataFrame]

  for (sc <- scenarios) {
    val stageMetrics = StageMetrics(sparkSession)

    println()
    println("***************************************")
    println(sc.name)
    println("***************************************")

    val size = stageMetrics.runAndMeasure {
      val in = sparkSession.read.parquet(s"/tmp/restitution/test-$id")
      Try {
        val df = sc.filter(in)
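        // The Parquet scan node printed by explain() lists the predicates handed to the
        // reader as "PushedFilters: [...]"; compare that with the bytes/records read below.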
        df.explain()
        df.collect().length
      } match {
        case Success(n) => n
        case Failure(e) =>
          println(e)
          -1
      }
    }

    testResult :+= getStats(sc.name, size, stageMetrics)
    println("***************************************")
    // scala.io.StdIn.readLine()
  }

  testResult.reduce(_ union _).show(false)

  println(s"END TESTS SPARK ${sparkSession.version}")
  sparkSession.stop()
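
  // One summary row per scenario, built from the PerfStageMetrics temp view registered by
  // createStageMetricsDF() and restricted to the stages submitted during the measured run.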
  def getStats(name: String, resultSize: Int, metrics: StageMetrics) = {
    val (begin, end) = (metrics.beginSnapshot, metrics.endSnapshot)
    metrics.createStageMetricsDF()

    sparkSession.sql(
      s"""select "$name" as testName, $resultSize as resultSize, count(*) numStages, sum(numTasks), max(completionTime) - min(submissionTime) as elapsedTime, sum(recordsRead), sum(bytesRead)
         |from PerfStageMetrics
         |where submissionTime >= $begin and completionTime <= $end
         """.stripMargin
    )
  }
}
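
To reproduce the runs below, the only non-Spark dependency is CERN's sparkmeasure (plus joda-time for the timestamp parsing). A minimal build.sbt sketch could look like the following; the Scala and library versions are assumptions chosen to match the 2017-era Spark releases tested here, so adjust them to your environment:

// build.sbt -- minimal sketch, not part of the original gist; versions are assumptions
scalaVersion := "2.11.11"

libraryDependencies ++= Seq(
  "org.apache.spark"     %% "spark-sql"     % "2.2.0",
  "ch.cern.sparkmeasure" %% "spark-measure" % "0.11",  // provides the StageMetrics used above
  "joda-time"            %  "joda-time"     % "2.9.9"
)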
Results (partial output)
SPARK 2.1.1, 2.1.2
SPARK 2.2.0