Created
September 30, 2022 15:26
-
-
Save chnlkw/3627a61df0b5e7e0fc266430ef380f51 to your computer and use it in GitHub Desktop.
Simple graph execution plan, with an evaluator and type inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A toy graph query engine: a small logical-plan language, an interpreter that runs plans
  * over an in-memory property graph, and a type checker for plans.
  *
  * Running the object prints the rows produced by each stage of a two-hop
  * friends-of-friends plan, the inferred type of each stage, and an (unevaluated)
  * recursive plan built from the extensions declared near the end.
  */
object ExamplePlan extends App {

  // ---------------------------------------------------------------------------
  // Plan language
  // ---------------------------------------------------------------------------

  /** A node of a logical query plan; table-producing operators take their input as `prev`. */
  sealed trait Plan

  /** Scans the `label` nodes whose properties equal the evaluated `filter` values and binds
    * each match (its properties plus its key under `"id"`) to `nodeVar`.
    */
  case class ScanNode(label: String, filter: Map[String, Plan], nodeVar: String) extends Plan

  /** For every input row, follows the outgoing `label` edges of the node bound to `srcRef`,
    * keeping edges whose properties equal the evaluated `filter` values. Binds the edge's
    * properties to `edgeVar` and the destination (key only, under `"id"`) to `dstVar`.
    */
  case class GetEdge(
      label: String,
      filter: Map[String, Plan],
      srcRef: String,
      edgeVar: String,
      dstVar: String,
      prev: Plan
  ) extends Plan

  /** Merges the stored properties of the `label` node bound to `nodeVar` into that binding. */
  case class GetNodeProp(label: String, nodeVar: String, prev: Plan) extends Plan

  /** Replaces each input row with the named `columns`, evaluated with the row's variables in scope. */
  case class Projection(columns: Map[String, Plan], prev: Plan) extends Plan

  /** Reads field `fieldName` of the record produced by `x`. */
  case class GetField(fieldName: String, x: Plan) extends Plan

  /** Reads variable `x` from the current context. */
  case class Ref(x: String) extends Plan

  /** A string literal. */
  case class VString(value: String) extends Plan

  // ---------------------------------------------------------------------------
  // Graph storage
  // ---------------------------------------------------------------------------

  /** An in-memory property graph with integer node keys. */
  trait Graph {
    type NodeKey = Int
    type NodeValue = Map[String, Any]
    type EdgeValue = Map[String, Any]

    /** Node label -> (node key -> node properties). */
    val v: Map[String, Map[NodeKey, NodeValue]]

    /** Edge label -> (source key -> list of (destination key, edge properties)). */
    val e: Map[String, Map[NodeKey, List[(NodeKey, EdgeValue)]]]

    /** All nodes carrying `label`; throws if the label is unknown. */
    def nodes(label: String): Seq[(NodeKey, NodeValue)] = v(label).toSeq

    /** The properties of node `key` under `label`, if present; throws if the label is unknown. */
    def node(label: String, key: NodeKey): Option[NodeValue] = v(label).get(key)

    /** Every `label` edge as a (source, properties, destination) triple. */
    def edges(label: String): Seq[(NodeKey, EdgeValue, NodeKey)] =
      for {
        (src, outgoing) <- e(label).toSeq
        (dst, props) <- outgoing
      } yield (src, props, dst)

    /** Outgoing `label` edges of `src` as (destination, properties) pairs.
      *
      * Nodes without outgoing edges are absent from the adjacency map, so they yield an
      * empty sequence rather than an exception; an unknown edge label still throws.
      */
    def neighbours(label: String, src: NodeKey): Seq[(NodeKey, EdgeValue)] =
      e(label).getOrElse(src, Nil)
  }

  /** A four-person social graph with Friend edges Alice -> Bob, Bob -> Charlie and Bob -> David. */
  val myGraph: Graph = new Graph {
    val v: Map[String, Map[NodeKey, NodeValue]] = Map(
      "Person" -> Map(
        1 -> Map("name" -> "Alice"),
        2 -> Map("name" -> "Bob"),
        3 -> Map("name" -> "Charlie"),
        4 -> Map("name" -> "David"),
      )
    )
    val e: Map[String, Map[NodeKey, List[(NodeKey, EdgeValue)]]] = Map(
      "Friend" -> Map(
        1 -> List(2 -> Map()),
        2 -> List(3 -> Map(), 4 -> Map()),
      )
    )
  }

  // ---------------------------------------------------------------------------
  // Interpreter
  // ---------------------------------------------------------------------------

  /** Evaluates plan `p` against graph `g`.
    *
    * Table-producing operators return a `Seq[Map[String, Any]]` (one map per row, keyed by
    * variable name); scalar expressions return their bare value. `context` holds the variable
    * bindings visible to [[Ref]].
    *
    * @throws NoSuchElementException if a bound node key has no stored node under the label
    * @throws UnsupportedOperationException for plan forms this interpreter does not handle
    */
  def evaluator(g: Graph, p: Plan, context: Map[String, Any]): Any = {
    // Evaluates a table-producing sub-plan; the cast reflects the untyped row representation.
    def rowsOf(q: Plan, ctx: Map[String, Any]): Seq[Map[String, Any]] =
      evaluator(g, q, ctx).asInstanceOf[Seq[Map[String, Any]]]

    // Evaluates every filter expression once, keyed by the property it constrains.
    def evalFilter(filter: Map[String, Plan], ctx: Map[String, Any]): Map[String, Any] =
      filter.map { case (prop, expr) => prop -> evaluator(g, expr, ctx) }

    // True when every wanted property is present in `props` with exactly the wanted value;
    // a missing property is a mismatch rather than an error.
    def matches(props: Map[String, Any], wanted: Map[String, Any]): Boolean =
      wanted.forall { case (prop, value) => props.get(prop).contains(value) }

    // The map bound to `variable` in `row`.
    def binding(row: Map[String, Any], variable: String): Map[String, Any] =
      row(variable).asInstanceOf[Map[String, Any]]

    // Node bindings carry their key under "id".
    def keyOf(node: Map[String, Any]): g.NodeKey = node("id").asInstanceOf[g.NodeKey]

    p match {
      case ScanNode(label, filter, nodeVar) =>
        // Filter values do not depend on the node under test, so evaluate them once.
        val wanted = evalFilter(filter, context)
        g.nodes(label).collect {
          case (nodeKey, nodeValue) if matches(nodeValue, wanted) =>
            Map(nodeVar -> (nodeValue + ("id" -> nodeKey)))
        }

      case GetEdge(label, filter, srcRef, edgeVar, dstVar, prev) =>
        rowsOf(prev, context).flatMap { row =>
          // Like Projection, filter expressions may refer to variables bound in the row.
          val wanted = evalFilter(filter, context ++ row)
          g.neighbours(label, keyOf(binding(row, srcRef))).collect {
            case (dstKey, edgeValue) if matches(edgeValue, wanted) =>
              row + (dstVar -> Map("id" -> dstKey)) + (edgeVar -> edgeValue)
          }
        }

      case GetNodeProp(label, nodeVar, prev) =>
        rowsOf(prev, context).map { row =>
          val node = binding(row, nodeVar)
          val nodeKey = keyOf(node)
          val props = g.node(label, nodeKey).getOrElse(
            throw NoSuchElementException(s"no $label node with id $nodeKey"))
          row + (nodeVar -> (node ++ props))
        }

      case Projection(columns, prev) =>
        rowsOf(prev, context).map { row =>
          val scope = context ++ row
          columns.map { case (colName, expr) => colName -> evaluator(g, expr, scope) }
        }

      case GetField(fieldName, x) => evaluator(g, x, context).asInstanceOf[Map[String, Any]](fieldName)
      case Ref(x) => context(x)
      case VString(value) => value
      case VInt(n) => n
      case VBool(value) => value
      case other =>
        throw UnsupportedOperationException(s"evaluator does not support $other")
    }
  }

  // ---------------------------------------------------------------------------
  // Two-hop example
  // ---------------------------------------------------------------------------

  /** The Cypher query that the hand-built plans p1..p5 below correspond to. */
  val getTwoHopNeighbourCypher: String =
    "MATCH (a:Person {name:'Alice'})-[e*2:Friend]->(b:Person) RETURN b.name"

  val p1 = ScanNode(label = "Person", filter = Map("name" -> VString("Alice")), nodeVar = "a")
  val p2 = GetEdge(label = "Friend", filter = Map(), srcRef = "a", edgeVar = "e", dstVar = "b1", prev = p1)
  val p3 = GetEdge(label = "Friend", filter = Map(), srcRef = "b1", edgeVar = "e2", dstVar = "b", prev = p2)
  val p4 = GetNodeProp(label = "Person", nodeVar = "b", prev = p3)
  val p5 = Projection(columns = Map("twoHopFriendName" -> GetField("name", Ref("b"))), prev = p4)

  println(evaluator(myGraph, p1, Map()))
  println(evaluator(myGraph, p2, Map()))
  println(evaluator(myGraph, p3, Map()))
  println(evaluator(myGraph, p4, Map()))
  println(evaluator(myGraph, p5, Map()))

  // ---------------------------------------------------------------------------
  // Types and type inference
  // ---------------------------------------------------------------------------

  /** Static type of a plan's result. */
  sealed trait Ty
  /** A record with named, typed fields. */
  case class Record(fields: Map[String, Ty]) extends Ty
  case object TString extends Ty
  case object TInt extends Ty
  case object TBool extends Ty
  /** A table whose rows all have type `row` (a [[Record]] for every operator above). */
  case class Table(row: Ty) extends Ty

  /** Static description of a graph's node and edge types. */
  trait GraphSchema {
    val nodeTypes: Map[String, (Ty, Map[String, Ty])] // label -> (keyType, propTypeMap)
    val edgeTypes: Map[String, (String, String, Map[String, Ty])] // edgeLabel -> (srcLabel, dstLabel, edgePropTypeMap)
  }

  /** Schema for [[myGraph]]: Person nodes keyed by Int with a String name; Friend edges between Persons with no properties. */
  val graphSchema: GraphSchema = new GraphSchema {
    override val nodeTypes: Map[String, (Ty, Map[String, Ty])] = Map("Person" -> (TInt, Map("name" -> TString)))
    override val edgeTypes: Map[String, (String, String, Map[String, Ty])] = Map("Friend" -> ("Person", "Person", Map()))
  }

  /** Infers the result type of `plan` under `graphSchema`; `ctx` gives the types of the
    * variables visible to [[Ref]].
    *
    * @throws Exception if the plan references an unknown variable, field or property, applies
    *                   an operator to an input of the wrong shape, filters with a value of the
    *                   wrong type, or uses a plan form inference does not support
    */
  def inferType(plan: Plan, graphSchema: GraphSchema, ctx: Map[String, Ty]): Ty = {
    // Infers `p` and requires a table of records, returning the record's columns.
    def columnsOf(p: Plan, scope: Map[String, Ty]): Map[String, Ty] =
      inferType(p, graphSchema, scope) match {
        case Table(Record(cols)) => cols
        case other => throw Exception(s"expected a table of records, got $other")
      }

    // Every filtered property must exist in `props` with the type of its filter expression.
    def checkFilter(filter: Map[String, Plan], props: Map[String, Ty], scope: Map[String, Ty]): Unit =
      filter.foreach { case (propName, expr) =>
        val expected = props.getOrElse(propName, throw Exception(s"property not found $propName"))
        val actual = inferType(expr, graphSchema, scope)
        if (actual != expected)
          throw Exception(s"property $propName has type $expected but its filter has type $actual")
      }

    plan match {
      case Ref(x) => ctx.getOrElse(x, throw Exception(s"variable not found $x"))
      case VString(_) => TString
      case VInt(_) => TInt
      case VBool(_) => TBool

      case ScanNode(label, filter, nodeVar) =>
        val (keyTy, propTy) = graphSchema.nodeTypes(label)
        checkFilter(filter, propTy, ctx)
        Table(Record(Map(nodeVar -> Record(propTy + ("id" -> keyTy)))))

      case Projection(columns, prev) =>
        val scope = ctx ++ columnsOf(prev, ctx)
        Table(Record(columns.map { case (colName, expr) => colName -> inferType(expr, graphSchema, scope) }))

      case GetField(fieldName, x) =>
        inferType(x, graphSchema, ctx) match {
          case Record(fields) => fields.getOrElse(fieldName, throw Exception(s"field not found $fieldName"))
          case _ => throw Exception("not record type")
        }

      case GetEdge(label, filter, srcRef, edgeVar, dstVar, prev) =>
        val prevCols = columnsOf(prev, ctx)
        // The evaluator reads the source node key from this binding, so it must already exist.
        if (!prevCols.contains(srcRef)) throw Exception(s"variable not found $srcRef")
        val (_, dstLabel, edgePropTy) = graphSchema.edgeTypes(label)
        val (dstKeyTy, _) = graphSchema.nodeTypes(dstLabel)
        checkFilter(filter, edgePropTy, ctx ++ prevCols)
        Table(Record(prevCols + (edgeVar -> Record(edgePropTy)) + (dstVar -> Record(Map("id" -> dstKeyTy)))))

      case GetNodeProp(label, nodeVar, prev) =>
        val prevCols = columnsOf(prev, ctx)
        val nodePropTy = graphSchema.nodeTypes(label)._2
        val nodeVarType = prevCols.getOrElse(nodeVar, throw Exception(s"variable not found $nodeVar")) match {
          case Record(fields) => Record(fields ++ nodePropTy)
          case _ => throw Exception("not record type")
        }
        Table(Record(prevCols + (nodeVar -> nodeVarType)))

      case other => throw Exception(s"type inference does not support $other")
    }
  }

  val t1 = inferType(p1, graphSchema, Map())
  val t2 = inferType(p2, graphSchema, Map())
  val t3 = inferType(p3, graphSchema, Map())
  val t4 = inferType(p4, graphSchema, Map())
  val t5 = inferType(p5, graphSchema, Map())
  println(t1)
  println(t2)
  println(t3)
  println(t4)
  println(t5)
  // For reference:
  // t1 == Table(Record(Map("a" -> Record(Map("name" -> TString, "id" -> TInt)))))
  // t5 == Table(Record(Map("twoHopFriendName" -> TString)))

  // ---------------------------------------------------------------------------
  // Extensions for recursive plans. Only the literals VInt and VBool are handled by
  // evaluator/inferType; the other forms are built and printed but rejected by both.
  // ---------------------------------------------------------------------------

  /** Recursive binding of `v` to `exp`, used in `next`; `recursivePlan` relies on `exp`
    * being able to refer to `v` itself.
    */
  case class LetRec(v: String, exp: Plan, next: Plan) extends Plan
  /** One-argument function with parameter `x` and body `v`. */
  case class Lambda(x: String, v: Plan) extends Plan
  /** Application of function `f` to argument `v`. */
  case class Apply(f: Plan, v: Plan) extends Plan
  /** An integer literal. */
  case class VInt(n: Int) extends Plan
  /** A boolean literal. */
  case class VBool(value: Boolean) extends Plan
  /** Conditional selecting `trueBody` or `falseBody` according to `cond`. */
  case class If(cond: Plan, trueBody: Plan, falseBody: Plan) extends Plan
  /** Built-in operator `op` applied to `args`; produced by the `<`, `>` and `+` syntax below. */
  case class PrimOp(op: String, args: List[Plan]) extends Plan
  /** An empty table (the recursion's base case in `recursivePlan`). */
  case class EmptyTable() extends Plan
  /** Concatenation of the tables `l` and `r` (presumably rows of `l`, then rows of `r`). */
  case class ConcatTable(l: Plan, r: Plan) extends Plan

  /** Operator syntax for building plans: `a < b`, `a > b`, `a + b` and `f.apply(x)`. */
  extension (p: Plan) {
    def <(r: Plan): Plan = PrimOp("<", List(p, r))
    def >(r: Plan): Plan = PrimOp(">", List(p, r))
    def +(r: Plan): Plan = PrimOp("+", List(p, r))
    def apply(r: Plan): Plan = Apply(p, r)
  }

  /** A recursive plan: starting from Alice with hop 0, while hop < 10 it emits the current
    * input and recurses on the nodes two gender-filtered Friend hops away. Printed, not evaluated.
    */
  val recursivePlan = LetRec(
    "func",
    Lambda("input", Lambda("hop",
      If(Ref("hop") < VInt(10),
        { // then: concat(input, func(input.getEdge.getEdge)(hop + 1))
          val a = Ref("input")
          val b = GetEdge(label = "Friend", filter = Map("gender" -> VBool(true)), srcRef = "a", edgeVar = "e", dstVar = "b", prev = a)
          val c = GetEdge(label = "Friend", filter = Map("gender" -> VBool(false)), srcRef = "b", edgeVar = "f", dstVar = "c", prev = b)
          // Rebind the reached node as "a" so the next iteration can start from it.
          val next = Projection(columns = Map("a" -> Ref("c")), prev = c)
          ConcatTable(
            Ref("input"),
            Ref("func").apply(next).apply(Ref("hop") + VInt(1))
          )
        },
        { // else
          EmptyTable()
        }
      )
    )),
    Ref("func").apply(ScanNode("Person", Map("name" -> VString("Alice")), "a")).apply(VInt(0))
  )
  println(recursivePlan)

  /** Variant of [[LetRec]] with an extra `ty` field (presumably the declared type of `v`); unused so far. */
  case class LetRec2(v: String, ty: Ty, exp: Plan, next: Plan) extends Plan

  /** Sketch of the plan language as a recursion scheme: `PlanF` is the one-layer pattern
    * functor, `Fix` ties the knot, and `TyPlan` pairs every layer with a [[Ty]].
    */
  object RecursiveScheme {
    case class Fix[F[_]](unfix: F[Fix[F]])
    sealed trait PlanF[A]
    case class ScanNode[A](label: String, filter: Map[String, A], nodeVar: String) extends PlanF[A]
    // `prev` sits in the recursive position `A` like the columns, so that e.g. TyPlanF
    // annotates the input sub-plan too instead of embedding an untyped Plan.
    case class Projection[A](columns: Map[String, A], prev: A) extends PlanF[A]
    type Plan = Fix[PlanF]
    case class TyPlanF[A](ty: Ty, plan: PlanF[A])
    type TyPlan = Fix[TyPlanF]
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment