Last active
November 6, 2022 14:59
-
-
Save blakewrege/a280b379d0d329467dd45c80c9993240 to your computer and use it in GitHub Desktop.
Parallel Prim's Algorithm for Spark GraphX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4,6
1,2,5
1,3,8
1,4,4
2,3,8
2,4,7
3,4,1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
((1,v1),(2,v2),5)
((1,v1),(3,v3),8)
((1,v1),(4,v4),4)
((2,v2),(3,v3),8)
((2,v2),(4,v4),7)
((3,v3),(4,v4),1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark._ | |
import org.apache.log4j.Logger | |
import org.apache.log4j.Level | |
import org.apache.spark.SparkContext | |
import org.apache.spark.SparkContext._ | |
import org.apache.spark.SparkConf | |
import org.apache.spark.graphx._ | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.graphx.util._ | |
object ParallelPrims {
  // Silence Spark ("org") and Akka log4j output so only this program's println output shows.
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  /** Reads a weighted edge list from `NodeData.txt`, builds a GraphX graph and prints up to
    * six of its triplets.
    *
    * Expected input: a header line `numNodes,numEdges`, followed by `src,dst,weight` lines.
    * Vertices are created with ids `1..numNodes` and labelled `"v<id>"`; only the first
    * `numEdges` edge lines are used.
    *
    * NOTE(review): despite the object's name, no Prim's/MST computation happens here yet —
    * this only loads and prints the graph.
    *
    * @throws IllegalArgumentException if the file has fewer edge lines than the header declares
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Parallel Prims").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      // Split each non-blank line into trimmed comma-separated fields. Cached because the RDD
      // is traversed once for the header and again for the edge rows.
      val rows: RDD[Array[String]] = sc
        .textFile("NodeData.txt", 2)
        .filter(_.trim.nonEmpty)
        .map(_.split(",").map(_.trim))
        .cache()

      // Header: number of nodes and number of edges.
      val header = rows.first()
      val numNodes = header(0).toInt
      val numEdges = header(1).toInt

      // Drop the header by position. Comparing each row's first column against the header's
      // first column (the previous approach) also discarded every edge whose source id
      // happened to equal numNodes.
      val edgeRows = rows.zipWithIndex().filter(_._2 > 0L).map(_._1)

      // Vertex ids are 1-based; each vertex is labelled "v<id>".
      val vertexArray = (1 to numNodes).map(id => (id.toLong, "v" + id))

      // Only the first numEdges rows are used; fail with a clear message if there are fewer.
      val edgeArray = edgeRows
        .take(numEdges)
        .map(cols => Edge(cols(0).toLong, cols(1).toLong, cols(2).toInt))
      require(
        edgeArray.length == numEdges,
        s"header declares $numEdges edges but only ${edgeArray.length} edge lines were found")

      val vertexRDD: RDD[(Long, String)] = sc.parallelize(vertexArray)
      val edgeRDD: RDD[Edge[Int]] = sc.parallelize(edgeArray)
      val graph: Graph[String, Int] = Graph(vertexRDD, edgeRDD)
      graph.triplets.take(6).foreach(println)
    } finally {
      // Release the local Spark context even if reading or parsing fails.
      sc.stop()
    }
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A graph vertex identified by an integer id, holding its outgoing weighted edges. */
class Node(id: Int) {

  /** Outgoing edges of this node, most recently added first. */
  var edges: List[Edge] = Nil

  /** Creates a directed edge from this node to `other` with the given weight, records it on
    * this node only (not on `other`), and returns the new edge.
    */
  def addEdge(other: Node, weight: Int): Edge = {
    val created = new Edge(this, other, weight)
    edges ::= created
    created
  }

  /** Renders the node as its numeric id. */
  override def toString(): String = id.toString
}
/** A directed, weighted edge from `from` to `to`.
  *
  * The natural ordering is deliberately inverted: an edge with a larger weight compares as
  * smaller. This lets a max-heap such as `mutable.PriorityQueue` yield the lightest edge first.
  * NOTE(review): as a consequence `a < b` holds when `a` is the heavier edge; an external
  * `Ordering` would make this less surprising.
  */
class Edge(val from: Node, val to: Node, val weight: Int) extends Ordered[Edge] {

  /** Inverted weight comparison: positive when this edge is lighter than `that`. */
  def compare(that: Edge): Int = Integer.compare(that.weight, weight)

  /** Renders as `from <--> to (weight)`. */
  override def toString(): String = s"$from <--> $to ($weight)"
}
object Main {

  /** Builds a small six-node undirected sample graph and prints its minimum spanning tree,
    * lightest edge first.
    */
  def main(args: Array[String]): Unit = {
    val nodes = (1 to 6).map(new Node(_)).toList

    // Undirected edges as (a, b, weight) with 1-based node ids. Listed in (a, b) order so
    // every node's adjacency list is built in the same order as the original hand-written
    // addEdge calls.
    val undirectedEdges = List(
      (1, 2, 6), (1, 3, 1), (1, 4, 5),
      (2, 3, 5), (2, 5, 3),
      (3, 4, 5), (3, 5, 6), (3, 6, 4),
      (4, 6, 2),
      (5, 6, 6)
    )

    // generateMST follows edges from `from` to `to`, so each undirected edge is registered
    // in both directions.
    undirectedEdges.foreach { case (a, b, weight) =>
      nodes(a - 1).addEdge(nodes(b - 1), weight)
      nodes(b - 1).addEdge(nodes(a - 1), weight)
    }

    // Sort explicitly by weight: `sortWith(_ < _)` on the inverted Ordered[Edge] would list
    // the heaviest edges first.
    generateMST(nodes).sortBy(_.weight).foreach(println)
  }

  /** Prim's algorithm (lazy variant).
    *
    *  - Pick a random start node.
    *  - Put its outgoing edges in a priority queue.
    *  - While the queue is not empty:
    *    - pop the lightest edge (Edge's inverted ordering makes the max-heap yield it first);
    *    - if its end-point is already in the tree, skip it;
    *    - otherwise add the edge and its end-point to the tree, and enqueue the end-point's
    *      edges that lead to nodes not yet in the tree.
    *
    * Only the connected component containing the start node is spanned. Edges are followed
    * from `from` to `to`, so an undirected graph must register each edge in both directions.
    * Nodes are tracked by identity, because `Node` does not override `equals`/`hashCode`.
    *
    * @param graph the nodes of the graph; an empty list yields an empty tree
    * @param rng   source of randomness for choosing the start node (pass a seeded instance
    *              for reproducible tie-breaking)
    * @return the tree's edges, most recently added first
    */
  def generateMST(graph: List[Node], rng: scala.util.Random = scala.util.Random): List[Edge] = {
    import scala.collection.mutable

    if (graph.isEmpty) Nil
    else {
      val startNode = graph(rng.nextInt(graph.length))
      // Hash set instead of List.contains: O(1) membership checks rather than O(V).
      val visited = mutable.HashSet[Node](startNode)
      var mstEdges = List.empty[Edge]
      val frontier = new mutable.PriorityQueue[Edge]
      frontier ++= startNode.edges

      while (frontier.nonEmpty) {
        val edge = frontier.dequeue()
        // add returns false when edge.to is already in the tree.
        if (visited.add(edge.to)) {
          mstEdges = edge :: mstEdges
          // Edges into visited nodes can never be selected, so don't enqueue them.
          frontier ++= edge.to.edges.filterNot(e => visited.contains(e.to))
        }
      }
      mstEdges
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment