@cedricbastin
Created February 2, 2016 19:12
Databricks coding challenge
// sbt run
// or: scalac GroupBy.scala; scala GroupByRun

import java.io._
import java.util.UUID

import scala.collection.immutable.HashMap
import scala.util.Random
/**
 * System information.
 * Holds memory-related constants so that we know how much data can be kept
 * in memory and when it has to be written to disk.
 * With the current approach the memory stays about 50% unused on average;
 * if we had easier access to memory information at runtime this could easily be improved.
 */
object SysInfo {
  def debug(s: String): Unit = if (false) println("debug: " + s)

  // memory constants
  val memory = 20 // in bytes
  val Ksize = 2
  val Vsize = 1
  var Kcount = 0 // currently unused, see the TODO in hasMemory
  var Vcount = 0

  def hasMemory[K, V](map: Map[K, List[V]]): Boolean = {
    // TODO: very expensive, should not be called on each iteration => keep running counters (Kcount/Vcount) instead
    val nbKeys = map.keys.size
    val nbVals = map.values.foldLeft(0)((b, list) => b + list.size)
    (Ksize * nbKeys) + (Vsize * nbVals) < memory
  }

  // names of the cache files
  val keyFile = "keyFile"
  val prefix = "BufferMap"
}
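// Worked example of the estimate above (illustration only, not from the original
// gist): a map with 4 keys holding 8 values in total is estimated at
// 4 * Ksize + 8 * Vsize = 4 * 2 + 8 * 1 = 16 < 20, so hasMemory still returns
// true; one more key with 4 values raises the estimate to 5 * 2 + 12 * 1 = 22 > 20,
// which triggers a flush on the next insertion.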
/**
 * The main entry point for the groupBy functionality.
 * Alternatively, the GroupBy trait can be mixed in wherever it is needed.
 */
object GroupByInst extends GroupBy

trait GroupBy {
  def groupBy[K, V](input: Iterator[(K, V)]): Iterator[(K, List[V])] = {
    implicit val uuid = java.util.UUID.randomUUID() // one cache directory per groupBy call
    val stream = input.foldLeft(new CachedMapStream[K, V](new HashMap())) {
      case (map, (k, v)) => map + (k -> v)
    }
    stream.iterator
  }
}
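// A minimal usage sketch (GroupByDemo is not part of the original gist). Keys and
// values must be serializable since they may be spilled to disk. Values come back
// in reverse insertion order because `+` prepends to the per-key list, and the
// tempdata<uuid> cache directory is not cleaned up afterwards.
object GroupByDemo extends App {
  val grouped = GroupByInst.groupBy(Iterator(("a", 1), ("a", 2), ("b", 3)))
  grouped.foreach { case (k, vs) => println(s"$k -> $vs") }
  // prints (in some order): a -> List(2, 1) and b -> List(3)
}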
/**
 * CachedMapStream takes care of writing the data to disk whenever the
 * (simulated) main memory is full.
 */
case class CachedMapStream[K, V](mymap: HashMap[K, List[V]])(implicit uuid: UUID) {
  import SysInfo._

  // initialize the base cache folder for this run
  val baseDir = "tempdata" + uuid
  val cachefile = new File(baseDir)
  cachefile.mkdir()
  /**
   * Add a new element to the mapping.
   */
  def +(kv: (K, V)): CachedMapStream[K, V] = {
    debug("add: " + kv)
    if (!hasMemory(mymap)) {
      flushAll()
      new CachedMapStream(new HashMap() + (kv._1 -> List(kv._2))) // start over with a fresh, empty map
    } else {
      kv match {
        case (k, v) =>
          CachedMapStream(mymap + (k -> (v :: mymap.getOrElse(k, Nil))))
      }
    }
  }
  /**
   * Retrieve all values for a key, merging the in-memory list with any
   * values previously spilled to disk.
   */
  def get(k: K): Option[List[V]] = {
    debug("get: " + k)
    val file = new File(baseDir, prefix + k.hashCode())
    val cachedValues =
      if (file.exists()) readAll[V](file)
      else Nil
    (mymap.getOrElse(k, Nil) ::: cachedValues) match {
      case Nil => None
      case xs  => Some(xs)
    }
  }

  def apply(k: K): List[V] = {
    debug("apply: " + k)
    get(k).getOrElse(Nil) // empty list as the default
  }
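  // Example with hypothetical values (not from the original gist): after a flush,
  // mymap may hold k -> List(v3) while the file BufferMap<k.hashCode> holds v2, v1
  // (written in that order); readAll reverses the write order, so get(k) returns
  // Some(List(v3, v1, v2)). The overall ordering of values is therefore not guaranteed.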
  /**
   * Whenever the memory limit is reached, write all in-memory data to disk:
   * values go to one file per key hash, keys are appended to the key file.
   */
  def flushAll() = {
    debug("flushAll")
    val kfile = new File(baseDir, keyFile)
    val kfos =
      if (kfile.createNewFile()) // new file: write the normal stream header
        new ObjectOutputStream(new FileOutputStream(kfile, true))
      else // appending: suppress the duplicate stream header (see AppendDemo below)
        new ObjectOutputStream(new FileOutputStream(kfile, true)) { override def writeStreamHeader = reset() }
    val newkeys = mymap.map {
      case (k, list) =>
        val file = new File(baseDir, prefix + k.hashCode()) // the hash code is shorter than the full key
        val fos =
          if (file.createNewFile())
            new ObjectOutputStream(new FileOutputStream(file, true))
          else
            new ObjectOutputStream(new FileOutputStream(file, true)) { override def writeStreamHeader = reset() }
        try {
          for (v <- list)
            fos.writeObject(v)
        } catch {
          case e: Throwable =>
            println(s"EXCEPTION while flushing the values of $k: $e")
        } finally {
          fos.flush()
          fos.close()
        }
        k
    }
    // only record keys that are not already present in the key file
    for (k <- removeDuplicates(newkeys))
      kfos.writeObject(k)
    kfos.flush()
    kfos.close()
  }
  /**
   * Filter out keys that were already written to the key file by an earlier flush.
   */
  def removeDuplicates(list: Iterable[K]): List[K] = {
    var deduped = list.toList
    val kfile = new File(baseDir, keyFile)
    if (kfile.exists()) {
      val kfis = new ObjectInputStream(new FileInputStream(kfile))
      try {
        while (true) { // terminates via EOFException once the key file is exhausted
          val k = kfis.readObject().asInstanceOf[K]
          if (deduped.contains(k))
            deduped = deduped diff List(k)
        }
      } catch {
        case e: Exception => // end of key file reached
      } finally {
        kfis.close()
      }
    }
    deduped
  }
  /**
   * Read all serialized objects from a file; generic so it works for both keys
   * and values. NOTE: elements are prepended, so they come back in reverse write order.
   */
  def readAll[A](file: File): List[A] = {
    debug("readall: " + file.getPath)
    var retList: List[A] = Nil
    val fis = new ObjectInputStream(new FileInputStream(file))
    try {
      while (true) { // fis.available() is unreliable here (always 0), so read until EOFException
        val elem = fis.readObject().asInstanceOf[A]
        retList = elem :: retList
      }
    } catch {
      case e: Throwable =>
        debug("Elements read: " + retList.mkString(", ") + " " + e)
    } finally {
      fis.close()
    }
    retList
  }
  /**
   * Iterate over all groups: first the keys still held in memory, then the keys
   * recorded in the key file, skipping any that the in-memory map already covers.
   */
  def iterator: Iterator[(K, List[V])] = {
    val kfile = new File(baseDir, keyFile)
    if (!kfile.exists()) {
      mymap.keysIterator.map(k => (k, apply(k)))
    } else {
      new Iterator[(K, List[V])] {
        val kit = mymap.keysIterator
        val kfis = new ObjectInputStream(new FileInputStream(kfile))
        var nextElem: Option[K] = None
        // pre-read the first on-disk key that is not already in memory
        try {
          var next = kfis.readObject().asInstanceOf[K]
          while (mymap.keySet.contains(next))
            next = kfis.readObject().asInstanceOf[K]
          nextElem = Some(next)
        } catch {
          case e: Exception =>
            debug("end of key file")
            nextElem = None
            kfis.close()
        }
        def hasNext = kit.hasNext || nextElem.isDefined
        def next = {
          if (kit.hasNext) {
            val key = kit.next()
            (key, apply(key)) // apply merges in-memory and spilled values
          } else {
            val tmp = nextElem.get
            // advance to the next on-disk key that is not already in memory
            try {
              var next = kfis.readObject().asInstanceOf[K]
              while (mymap.keySet.contains(next))
                next = kfis.readObject().asInstanceOf[K]
              nextElem = Some(next)
            } catch {
              case e: Exception =>
                debug("end of key file")
                nextElem = None
                kfis.close()
            }
            (tmp, apply(tmp))
          }
        }
      }
    }
  }
}
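// Standalone sketch (AppendDemo is not part of the original gist) of the
// writeStreamHeader trick used in flushAll above. Every ObjectOutputStream writes
// a stream header on construction; naively appending with a second stream would
// embed a second header mid-file and make ObjectInputStream fail with a
// StreamCorruptedException. Overriding writeStreamHeader to emit a reset marker
// instead keeps the appended bytes readable by a single ObjectInputStream.
object AppendDemo extends App {
  val f = new File("append-demo.bin")
  val first = new ObjectOutputStream(new FileOutputStream(f))
  first.writeObject("one")
  first.close()
  val second = new ObjectOutputStream(new FileOutputStream(f, true)) {
    override def writeStreamHeader = reset() // suppress the duplicate header
  }
  second.writeObject("two")
  second.close()
  val in = new ObjectInputStream(new FileInputStream(f))
  println(in.readObject()) // one
  println(in.readObject()) // two
  in.close()
  f.delete()
}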
// A basic runner to quickly try out the groupBy implementation.
object GroupByRun extends App {
  val random = new Random(java.lang.System.currentTimeMillis())
  val it = (0 until 50).map(_ => (random.nextInt(20), random.nextInt(1000)))
  // alternative test inputs (currently unused)
  val it1 = (0, 0) :: (1, 0) :: (1, 1) :: (2, 0) :: (2, 1) :: (2, 2) :: Nil
  // wrapper class as a more complicated case of an instance which needs to be serialized to disk
  case class Wrap(i: Int)
  val it2 = (for (i <- 0 to 30) yield for (j <- 0 to i) yield (Wrap(i), Wrap(j))).flatten

  val res = GroupByInst.groupBy(it.toIterator).toList

  println("INPUT:")
  it.sorted.foldLeft(0)(printByKey(_, _))
  println()
  println()
  println("OUTPUT:")
  res.sortBy(_._1).foldLeft(0)(printByKey(_, _))
  println()

  // prints tuples grouped by key, starting a new line whenever the key changes
  def printByKey[A, B](arg: (A, (A, B))): A = arg match {
    case (acc, tup) =>
      if (acc != tup._1)
        println()
      print(tup)
      tup._1
  }
}
@cedricbastin (author) commented:

Oh yeah just in case: this will not give you a job offer from Databricks!