Last active
May 24, 2025 17:30
-
-
Save takezoe/b9e6f4b92bff65a3adb6fca7ebce6fd9 to your computer and use it in GitHub Desktop.
Simple utility to rewrite duplicated subqueries to CTE
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.github.takezoe.cte | |
import wvlet.airframe.sql.model.Expression.UnquotedIdentifier | |
import wvlet.airframe.sql.model.LogicalPlan | |
import wvlet.airframe.sql.model.LogicalPlan._ | |
import wvlet.airframe.sql.parser.SQLParser.{parse => parseSQL} | |
import wvlet.airframe.sql.parser.SQLGenerator.{print => toSQL} | |
// Simple utility that rewrites duplicated subqueries as CTEs (common table expressions).
//
// To use it, add the following library dependency to your project:
//   "org.wvlet.airframe" %% "airframe-sql" % "2025.1.10"
/**
  * Rewrites subqueries that occur repeatedly in a SQL statement into common table
  * expressions (CTEs), so each duplicated subquery is defined once in a WITH clause
  * and referenced by name everywhere it was used.
  *
  * Subqueries are compared by their generated SQL text, so only subqueries that print
  * to exactly the same SQL are treated as duplicates.
  *
  * @param threshold           minimum number of occurrences for a subquery to be extracted
  * @param convertSimpleSelect when false, a SELECT whose only child is a plain table
  *                            reference is never extracted
  */
class SubqueryToCteRewriter(threshold: Int = 2, convertSimpleSelect: Boolean = false) {

  /**
    * Rewrites `sql` so that duplicated subqueries become CTEs.
    *
    * @param sql the SQL statement to rewrite
    * @return the rewritten SQL, or `sql` itself (unchanged) when nothing qualifies
    */
  def rewrite(sql: String): String = {
    val logicalPlan = parseSQL(sql)
    // (subquery SQL, CTE name) pairs in the order the subqueries were first encountered
    val namedSubqueries = extractDuplicatedSubqueries(logicalPlan, identifiersIn(sql))
    if (namedSubqueries.isEmpty) {
      sql
    } else {
      val nameBySql = namedSubqueries.toMap

      // Replace each duplicated subquery with a SELECT from its CTE name
      val result = logicalPlan.transform {
        case p: Project =>
          nameBySql.get(toSQL(p)) match {
            case Some(name) => parseSQL(s"SELECT * FROM $name")
            case _          => p
          }
      }

      // Build the CTE definitions from the original subquery SQL
      val queries = namedSubqueries.map { case (subquerySql, name) =>
        parseSQL(subquerySql) match {
          case relation: Relation =>
            WithQuery(UnquotedIdentifier(name, None), relation, None, None)
          case other =>
            throw new IllegalStateException(
              s"Subquery did not parse to a relation: ${other.getClass.getName}"
            )
        }
      }

      // Add or update the WITH clause at the root of the statement only. Appending to
      // every With node in the tree would duplicate the definitions inside nested WITH
      // clauses, and would leave outer references undefined when the only existing
      // WITH clause is a nested one.
      val newPlan = result match {
        case Query(existing, body, location) =>
          Query(existing.copy(queries = existing.queries ++ queries), body, location)
        case relation: Relation =>
          Query(With(recursive = false, queries, None), relation, None)
        case other if hasWith(other) =>
          // Root is not a relation (e.g. a statement wrapping a query): extend the
          // WITH clause(s) it already contains.
          other.transform {
            case w: With => w.copy(queries = w.queries ++ queries)
          }
        case other =>
          throw new IllegalStateException(
            s"Cannot attach a WITH clause to ${other.getClass.getName}"
          )
      }
      toSQL(newPlan)
    }
  }

  /**
    * Finds the subqueries occurring at least `threshold` times in `p` and assigns each
    * a CTE name.
    *
    * @param p             the plan to inspect
    * @param reservedNames lower-cased identifiers that must not be used as CTE names
    * @return (subquery SQL, CTE name) pairs, in the order the subqueries were first seen
    */
  private def extractDuplicatedSubqueries(
      p: LogicalPlan,
      reservedNames: Set[String]
  ): Seq[(String, String)] = {
    // Insertion-ordered so that name assignment follows traversal order rather than
    // hash order
    val occurrences = scala.collection.mutable.LinkedHashMap.empty[String, Int]
    p.traverse {
      case project: Project if convertSimpleSelect || !isSimpleSelect(project) =>
        val key = toSQL(project)
        occurrences.update(key, occurrences.getOrElse(key, 0) + 1)
    }

    // t0, t1, ... skipping any name that already appears in the statement, so a
    // generated CTE never shadows an existing table or CTE
    val freshNames = Iterator.from(0).map(i => s"t$i").filterNot(reservedNames)
    occurrences.iterator.collect {
      case (subquerySql, count) if count >= threshold => subquerySql -> freshNames.next()
    }.toList
  }

  /** Returns every identifier-like token of `sql`, lower-cased. */
  private def identifiersIn(sql: String): Set[String] = {
    "[A-Za-z_][A-Za-z0-9_]*".r
      .findAllIn(sql)
      .map(_.toLowerCase(java.util.Locale.ROOT))
      .toSet
  }

  /** Returns true if `p` contains a With node anywhere in its tree. */
  private def hasWith(p: LogicalPlan): Boolean = {
    var hasWithClause = false
    p.traverse {
      case _: With => hasWithClause = true
    }
    hasWithClause
  }

  /** Returns true if the only child of `p` is a plain table reference. */
  private def isSimpleSelect(p: Project): Boolean = {
    p.children match {
      case Seq(_: TableRef) => true
      case _                => false
    }
  }
}
/**
  * Demo entry point: rewrites a sample UNION ALL query and prints the result.
  *
  * An explicit `main` is used instead of extending `App` so that the demo runs only
  * when this object is launched as a program, not as part of the companion object's
  * initialization.
  */
object SubqueryToCteRewriter {
  def main(args: Array[String]): Unit = {
    val rewriter = new SubqueryToCteRewriter(convertSimpleSelect = true)
    println(rewriter.rewrite("SELECT * FROM a WHERE id = 1 UNION ALL SELECT * FROM a WHERE id = 2"))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment