@GrigorievNick
Last active January 23, 2025
Spark Structured Streaming graceful shutdown on SIGTERM. SIGTERM will not interrupt the currently running batch, but due to the async nature of StreamingQueryListener.onQueryProgress, it can interrupt the next batch during its first few moments.
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.slf4j.LoggerFactory
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.SynchronousQueue
import java.util.concurrent.TimeUnit
/**
 * StreamingQuery.stop() does not wait until the current batch is ready:
 * it immediately cancels all Spark jobs generated by the stream and interrupts the query execution thread.
 * This listener registers a shutdown hook that waits until the current batch has finished,
 * then tries to stop the stream as fast as possible, before the next batch is generated.
 * But because onQueryProgress is async, it is possible that the next batch will start and be canceled.
 */
class GracefulStopOnShutdownListener(streams: StreamingQueryManager) extends StreamingQueryListener {

  private val log = LoggerFactory.getLogger(getClass)
  private val runningQuery = new ConcurrentHashMap[UUID, (Runnable, SynchronousQueue[Boolean])]()

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    val stream = streams.get(event.id)
    // Rendezvous channel: the shutdown hook blocks in put() until onQueryProgress polls,
    // i.e. until the current batch has finished.
    val stopSignalChannel = new SynchronousQueue[Boolean]()
    val shutdownHook: Runnable = () => {
      if (stream.isActive) {
        log.info(s"Stop signal arrived, query ${stream.id} waits until the current batch is ready")
        val stopSignal = true
        stopSignalChannel.put(stopSignal) // blocks until a batch boundary is reached
        log.info(s"Send stop to ${stream.id}")
        stream.stop()
        stream.awaitTermination()
        log.info(s"Query ${stream.id} stopped")
      }
    }
    ShutdownHookManager.get().addShutdownHook(shutdownHook, 100, 20, TimeUnit.MINUTES)
    runningQuery.put(stream.id, (shutdownHook, stopSignalChannel))
    log.info(s"Registered shutdown hook for query ${event.id}")
  }

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    log.info(s"Query ${event.progress.id} batch ready " + event.progress.batchId)
    val (_, stopSignalChannel) = runningQuery.get(event.progress.id)
    // Non-blocking: releases the shutdown hook if it is waiting at this batch boundary.
    stopSignalChannel.poll()
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    val (shutdownHook, stopSignalChannel) = runningQuery.remove(event.id)
    log.info(s"Does shutdown hook for ${event.id} exist: " + ShutdownHookManager.get().hasShutdownHook(shutdownHook))
    // removeShutdownHook throws if shutdown is already in progress, so guard against it.
    if (!ShutdownHookManager.get().isShutdownInProgress) ShutdownHookManager.get().removeShutdownHook(shutdownHook)
    log.info(s"Query ${event.id} shutdown, release hook.")
    // Drain a pending stop signal so a hook already blocked in put() is released.
    stopSignalChannel.poll()
  }
}
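The listener hinges on the SynchronousQueue rendezvous: put() blocks the calling thread until another thread polls. A minimal stand-alone sketch of that handshake (HandshakeDemo is a hypothetical name, not part of the gist):

import java.util.concurrent.SynchronousQueue

// Stand-alone illustration of the put()/poll() rendezvous used by the listener above.
object HandshakeDemo extends App {
  val channel = new SynchronousQueue[Boolean]()

  // Plays the role of the shutdown hook: blocks until a batch boundary is reached.
  val hook = new Thread(() => {
    println("hook: waiting for a batch boundary")
    channel.put(true) // blocks here until the listener thread polls
    println("hook: batch boundary reached, safe to stop the query")
  })
  hook.start()

  Thread.sleep(1000) // emulate a batch still running
  println("listener: onQueryProgress fired, draining the stop signal")
  channel.poll() // non-blocking; releases the hook thread if it is waiting
  hook.join()
}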
import com.appsflyer.raw.data.ingestion.GracefulStopOnShutdownListener
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import scala.util.Try

object Main {

  val log = LoggerFactory.getLogger(getClass)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    val dummyTablePath = "/tmp/table_path"
    import spark.implicits._
    (0 to 100).toDF.write.mode(SaveMode.Overwrite).parquet(dummyTablePath)

    // Background writer that keeps appending small files so the stream always has new data.
    val ingestion = new Thread(() =>
      Try {
        while (true) {
          (0 to 2).map(i => i -> s"${i}DummyVal").toDF.write.mode(SaveMode.Append).parquet(dummyTablePath)
          Thread.sleep(1000)
        }
      }
    )
    ingestion.start()

    def runBatch(data: DataFrame, id: Long): Unit = {
      spark.sparkContext.setJobDescription(id.toString)
      log.info(s"start batch $id")
      // Emulate long job execution
      val res = data.rdd.map { x => Thread.sleep(10.seconds.toMillis); x }.take(2)
      log.info(s"Batch $id size ${res.mkString("Array(", ", ", ")")}")
    }

    // Comment out this line to reproduce the issue
    spark.streams.addListener(new GracefulStopOnShutdownListener(spark.streams))

    val stream = spark
      .readStream
      .schema(StructType(Seq(StructField("int", IntegerType))))
      .option("maxFilesPerTrigger", "1")
      .parquet(dummyTablePath)
      .writeStream
      .foreachBatch(runBatch _)
      .start()

    // Uncomment to reproduce the issue: a naive hook that calls stop() directly,
    // without waiting for a batch boundary, so the in-flight batch is interrupted.
    // val shutdownHook: Runnable = () => {
    //   if (stream.isActive) {
    //     log.info(s"Stop signal arrived, query ${stream.id} waits until the current batch is ready")
    //     stream.stop()
    //     stream.awaitTermination()
    //     log.info(s"Query ${stream.id} stopped")
    //   }
    // }
    // ShutdownHookManager.get().addShutdownHook(shutdownHook, 100, 10, TimeUnit.MINUTES)

    // Dummy hook to keep the JVM alive long enough to check the Spark UI during shutdown.
    ShutdownHookManager.get().addShutdownHook(() => Thread.sleep(1000000000), 99, 10, TimeUnit.MINUTES)

    stream.awaitTermination()
    println("finish")
  }
}
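To exercise the graceful path, run Main and send SIGTERM to the JVM (for example kill -TERM <pid>). The hook registered by the listener blocks in put() until the in-flight batch reports progress, then stops the query. With the listener registration commented out and the naive hook enabled instead, stop() cancels the running batch's jobs mid-flight.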
@idkburkes
Great job on this! This really should be native to spark structured streaming without the custom implementation.

@GrigorievNick (Author)
With Spark 3.5, the Structured Streaming query listener interface gained a new callback, onQueryIdle, which can help make this even more graceful.
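A minimal sketch of what that could look like, assuming the Spark 3.5+ onQueryIdle callback (field names per the 3.5 API; untested with this listener):

// Hedged sketch, Spark 3.5+ only: an idle trigger means no batch is running,
// so it is also a safe point to release a waiting shutdown hook.
override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit =
  Option(runningQuery.get(event.id)).foreach { case (_, stopSignalChannel) =>
    stopSignalChannel.poll() // drain a pending stop signal between triggers
  }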
