Spark Structured Streaming graceful shutdown on SIGTERM. SIGTERM will not interrupt the currently running batch, but because StreamingQueryListener.onQueryProgress is asynchronous, it can interrupt the next batch during its first few moments.
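Background for the approach: SIGTERM triggers the JVM's shutdown hooks, and Hadoop's ShutdownHookManager (used below) runs registered hooks in priority order, higher priority first. A minimal, self-contained sketch (names and timings are illustrative):

import java.util.concurrent.TimeUnit
import org.apache.hadoop.util.ShutdownHookManager

object SigtermDemo extends App {
  // Higher priority runs earlier; the timeout bounds how long the
  // shutdown sequence waits for this hook to finish.
  ShutdownHookManager.get().addShutdownHook(
    () => println("JVM shutting down, hooks running"),
    100, 20, TimeUnit.MINUTES)

  Thread.sleep(60000) // send `kill -TERM <pid>` while this sleeps
}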
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.slf4j.LoggerFactory
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.SynchronousQueue
import java.util.concurrent.TimeUnit
/**
 * Stopping a streaming query does not wait until the current batch is ready:
 * it immediately cancels all Spark jobs generated by the stream and interrupts
 * the query execution thread.
 * This listener instead registers a shutdown hook that waits until the current
 * batch has finished, then tries to stop the stream as fast as possible,
 * before the next batch is generated.
 * But because onQueryProgress is async, it is still possible that the next
 * batch will start and be cancelled.
 */
class GracefulStopOnShutdownListener(streams: StreamingQueryManager) extends StreamingQueryListener {

  private val log = LoggerFactory.getLogger(getClass)
  private val runningQuery = new ConcurrentHashMap[UUID, (Runnable, SynchronousQueue[Boolean])]()

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    val stream = streams.get(event.id)
    val stopSignalChannel = new SynchronousQueue[Boolean]()
    val shutdownHook: Runnable = () => {
      if (stream.isActive) {
        log.info(s"Stop signal arrived, query ${stream.id} waits until the current batch is ready")
        val stopSignal = true
        // put() blocks until onQueryProgress polls, i.e. until the current batch has finished.
        stopSignalChannel.put(stopSignal)
        log.info(s"Send stop ${stream.id}")
        stream.stop()
        stream.awaitTermination()
        log.info(s"Query ${stream.id} stopped")
      }
    }
    ShutdownHookManager.get().addShutdownHook(shutdownHook, 100, 20, TimeUnit.MINUTES)
    runningQuery.put(stream.id, (shutdownHook, stopSignalChannel))
    log.info(s"Registered shutdown hook for query ${event.id}")
  }

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    log.info(s"Query ${event.progress.id} batch ready " + event.progress.batchId)
    // The batch has finished: release the shutdown hook if it is waiting.
    val (_, stopSignalChannel) = runningQuery.get(event.progress.id)
    stopSignalChannel.poll()
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    val (shutdownHook, stopSignalChannel) = runningQuery.remove(event.id)
    log.info(s"Does shutdown hook for ${event.id} exist: " + ShutdownHookManager.get().hasShutdownHook(shutdownHook))
    if (!ShutdownHookManager.get().isShutdownInProgress) ShutdownHookManager.get().removeShutdownHook(shutdownHook)
    log.info(s"Query ${event.id} shut down, released hook.")
    // Unblock the shutdown hook thread in case it is still blocked in put().
    stopSignalChannel.poll()
  }
}
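To use the listener, register it on the session before any query starts, so that onQueryStarted can install the hook. A minimal sketch (source and sink are illustrative; the Main object below is the full demo):

import org.apache.spark.sql.SparkSession

object Usage extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  // Register before start() so onQueryStarted installs the shutdown hook.
  spark.streams.addListener(new GracefulStopOnShutdownListener(spark.streams))

  // Hypothetical query for illustration; any source/sink pair works.
  val query = spark.readStream
    .format("rate")   // built-in test source: columns timestamp, value
    .load()
    .writeStream
    .format("console")
    .start()

  query.awaitTermination()
}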
import com.appsflyer.raw.data.ingestion.GracefulStopOnShutdownListener
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
import scala.util.Try
object Main {

  val log = LoggerFactory.getLogger(getClass)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    val dummyTablePath = "/tmp/table_path"
    import spark.implicits._
    (0 to 100).toDF.write.mode(SaveMode.Overwrite).parquet(dummyTablePath)
    // Continuously append small files so the stream always has new data.
    val ingestion = new Thread(() =>
      Try {
        while (true) {
          (0 to 2).map(i => i -> s"${i}DummyVal").toDF.write.mode(SaveMode.Append).parquet(dummyTablePath)
          Thread.sleep(1000)
        }
      }
    )
    ingestion.start()

    def runBatch(data: DataFrame, id: Long): Unit = {
      spark.sparkContext.setJobDescription(id.toString)
      log.info(s"start batch $id")
      // Emulate a long job execution
      val res = data.rdd.map { x => Thread.sleep(10.seconds.toMillis); x }.take(2)
      log.info(s"Batch $id size ${res.mkString("Array(", ", ", ")")}")
    }

    // Comment out the next line to reproduce the issue
    spark.streams.addListener(new GracefulStopOnShutdownListener(spark.streams))
    val stream = spark
      .readStream
      .schema(StructType(Seq(StructField("int", IntegerType))))
      .option("maxFilesPerTrigger", "1")
      .parquet(dummyTablePath)
      .writeStream
      .foreachBatch(runBatch _)
      .start()
    // Uncomment to reproduce the issue: a naive hook stops the stream mid-batch.
    // val shutdownHook: Runnable = () => {
    //   if (stream.isActive) {
    //     log.info(s"Stop signal arrived, query ${stream.id} waits until the current batch is ready")
    //     stream.stop()
    //     stream.awaitTermination()
    //     log.info(s"Query ${stream.id} stopped")
    //   }
    // }
    // ShutdownHookManager.get().addShutdownHook(shutdownHook, 100, 10, TimeUnit.MINUTES)

    // Dummy hook to keep the JVM alive during shutdown so the Spark UI can be checked
    ShutdownHookManager.get().addShutdownHook(() => Thread.sleep(1000000000), 99, 10, TimeUnit.MINUTES)
    stream.awaitTermination()
    println("finish")
  }
}
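One way to exercise the hooks is to run Main and send kill -TERM to the driver PID while a batch is running. Since System.exit also runs JVM shutdown hooks, a hypothetical tweak for a self-contained local test could be:

// Hypothetical local test: schedule an exit that triggers the shutdown
// hooks mid-stream, instead of sending SIGTERM from another terminal.
new Thread(() => {
  Thread.sleep(15000) // let a few batches run first
  sys.exit(0)         // runs the registered shutdown hooks
}).start()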
Thank you for sharing - very useful for me. I had to convert it to Java - so I wanted to share this, too.
import static java.util.Objects.requireNonNull;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;
import org.apache.hadoop.util.ShutdownHookManager;
import org.apache.spark.sql.streaming.StreamingQueryException;
import org.apache.spark.sql.streaming.StreamingQueryListener;
import org.apache.spark.sql.streaming.StreamingQueryManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Credits to https://gist.github.com/GrigorievNick/bf920e32f70cb1cf8308cd601e415d12
*/
public class GracefulStopOnShutdownListener extends StreamingQueryListener {

    private static final Logger logger = LoggerFactory.getLogger(GracefulStopOnShutdownListener.class);

    private final StreamingQueryManager streamingQueryManager;
    private final Map<UUID, Tuple> runningQuery = new ConcurrentHashMap<>();

    public GracefulStopOnShutdownListener(StreamingQueryManager streamingQueryManager) {
        this.streamingQueryManager = requireNonNull(streamingQueryManager);
    }

    @Override
    public void onQueryStarted(QueryStartedEvent event) {
        final var stream = streamingQueryManager.get(event.id());
        var stopSignalChannel = new SynchronousQueue<Boolean>();
        Runnable shutdownHook = () -> {
            if (stream.isActive()) {
                try {
                    logger.info("Stop signal arrived, query {} waits until the current batch is ready", stream.id());
                    stopSignalChannel.put(Boolean.TRUE);
                    logger.info("Send stop {}", stream.id());
                    stream.stop();
                    stream.awaitTermination();
                    logger.info("Query {} stopped", stream.id());
                } catch (InterruptedException e) {
                    logger.warn("Interrupted", e);
                    Thread.currentThread().interrupt();
                } catch (StreamingQueryException e) {
                    logger.warn("Unexpected exception", e);
                    throw new RuntimeException(e);
                }
            }
        };
        ShutdownHookManager.get().addShutdownHook(shutdownHook, 100, 20, TimeUnit.MINUTES);
        runningQuery.put(stream.id(), new Tuple(shutdownHook, stopSignalChannel));
        logger.info("Registered shutdown hook for query {}", stream.id());
    }

    @Override
    public void onQueryProgress(QueryProgressEvent event) {
        logger.info("onQueryProgress: Query batch {} finished", event.progress().batchId());
        var tuple = runningQuery.get(event.progress().id());
        tuple.stopSignalChannel.poll();
    }

    @Override
    public void onQueryTerminated(QueryTerminatedEvent event) {
        var tuple = runningQuery.remove(event.id());
        logger.info("Does shutdown hook for {} exist: {}", event.id(), ShutdownHookManager.get().hasShutdownHook(tuple.shutdownHook));
        if (!ShutdownHookManager.get().isShutdownInProgress()) ShutdownHookManager.get().removeShutdownHook(tuple.shutdownHook);
        logger.info("Query {} shut down, released hook.", event.id());
        tuple.stopSignalChannel.poll();
    }

    private static class Tuple {
        final Runnable shutdownHook;
        final SynchronousQueue<Boolean> stopSignalChannel;

        Tuple(Runnable shutdownHook, SynchronousQueue<Boolean> stopSignalChannel) {
            this.shutdownHook = shutdownHook;
            this.stopSignalChannel = stopSignalChannel;
        }
    }
}
Great job on this! This really should be native to spark structured streaming without the custom implementation.
With Spark 3.5, the Structured Streaming query listener interface gained a new callback, onQueryIdle, which can help make this even more graceful.
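For reference, a sketch (assuming Spark 3.5+) of the extra override one might add to the Scala listener above; the body mirrors onQueryProgress and is illustrative only:

// Inside GracefulStopOnShutdownListener (Spark 3.5+ only):
override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
  // A trigger fired but produced no new data; release a pending stop
  // signal so the shutdown hook is not left blocked on an idle stream.
  Option(runningQuery.get(event.id)).foreach { case (_, stopSignalChannel) =>
    stopSignalChannel.poll()
  }
}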
The stopSignalChannel.poll in onQueryTerminated is there to unblock the shutdown hook thread. A put leaves the calling thread blocked until some other thread calls take or poll. Please check the Javadoc of this class for more details: stopSignalChannel is a SynchronousQueue. As I said, just check how this class behaves and what its purpose is. Briefly: put blocks the shutdown hook thread until poll is called from the other listener thread.
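To make the handshake concrete, a self-contained sketch (timings illustrative):

import java.util.concurrent.SynchronousQueue

object HandshakeDemo extends App {
  val channel = new SynchronousQueue[Boolean]()

  // Plays the role of the shutdown hook thread.
  val hook = new Thread(() => {
    channel.put(true) // blocks until another thread calls take/poll
    println("released: batch finished, safe to stop the stream")
  })
  hook.start()

  Thread.sleep(1000) // emulate the current batch still running
  channel.poll()     // plays the role of onQueryProgress / onQueryTerminated
  hook.join()
}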