Skip to content

Instantly share code, notes, and snippets.

@InvisibleTech
Created October 10, 2015 04:38

Revisions

  1. InvisibleTech created this gist Oct 10, 2015.
    52 changes: 52 additions & 0 deletions pivotRDD.scala
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,52 @@
    /*
    This Pivot sample is based on the 5th answer given on:
    https://stackoverflow.com/questions/30260015/reshaping-pivoting-data-in-spark-rdd-and-or-spark-dataframes
    The answer above was written in Python, which I don't know very well. In addition, my Spark-Fu
    is still somewhat introductory in some areas. To help with other aspects of translating the Python
    sample I used these references:
    http://codingjunkie.net/spark-agr-by-key/
    http://scala4fun.tumblr.com/post/84792374567/mergemaps
    http://alvinalexander.com/scala/how-sort-scala-sequences-seq-list-array-buffer-vector-ordering-ordered
    */

    import scala.collection._

    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.types.{StructType,StructField,StringType, IntegerType};

    // Sample input: (id, age, country, count) tuples; the data will be
    // pivoted so each country becomes its own column.
    val rdd = sc.parallelize(List(
      ("X01", 41, "US", 3),
      ("X01", 41, "UK", 1),
      ("X01", 41, "CA", 2),
      ("X02", 72, "US", 4),
      ("X02", 72, "UK", 6),
      ("X02", 72, "CA", 7),
      ("X02", 72, "XX", 8)))

    // Turn each tuple into a generic Row and key it by its first field (the ID),
    // giving a pair RDD suitable for aggregateByKey.
    val krows = rdd
      .map(tuple => Row.fromSeq(tuple.productIterator.toList))
      .keyBy(row => row(0).toString)

    /** Sequence op for aggregateByKey: folds one input Row into the per-key
      * accumulator map. Records the country->count pair (fields 2 and 3) and
      * the "Age" column (field 1), then returns the mutated accumulator.
      */
    def seqPivot(m: mutable.Map[String, Any], r: Row): mutable.Map[String, Any] =
      m ++= Seq(r(2).toString -> r(3), "Age" -> r(1))

    /** Combine op for aggregateByKey: merges partition-level accumulators by
      * folding m2's entries into m1 (m2 wins on duplicate keys) and returns m1.
      */
    def cmbPivot(m1: mutable.Map[String, Any], m2: mutable.Map[String, Any]): mutable.Map[String, Any] =
      m1 ++= m2

    // Run the pivot: one map per ID, e.g.
    //   "X01" -> Map("US" -> 3, "UK" -> 1, "CA" -> 2, "Age" -> 41)
    val pivoted = krows.aggregateByKey(mutable.Map.empty[String, Any])(seqPivot, cmbPivot)

    // Union of every column name seen across all keys, sorted alphabetically
    // so the output column order is deterministic.
    val orderedColnames = pivoted.values.map(_.keys.toSet).reduce(_ union _).toSeq.sorted

    // Schema: the string ID column followed by one nullable integer column per pivot key.
    val schema = StructType(
      StructField("ID", StringType, true) ::
        orderedColnames.map(c => StructField(c, IntegerType, true)).toList)

    // Re-assemble each key's map into a Row in schema order; null fills any
    // column a particular key never saw (e.g. "XX" for X01).
    val keyedRows = pivoted
      .map { case (id, cols) => id :: orderedColnames.map(c => cols.getOrElse(c, null)).toList }
      .map(Row.fromSeq)

    val result = sqlContext.createDataFrame(keyedRows, schema)

    result.show