Last active
November 15, 2020 09:15
-
-
Save HeartSaVioR/9a3aeeef0f1d8ee97516743308b14cd6 to your computer and use it in GitHub Desktop.
Implementation of session window with event time and watermark via flatMapGroupsWithState, and SPARK-10816
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
case class SessionInfo(sessionStartTimestampMs: Long, | |
sessionEndTimestampMs: Long, | |
numEvents: Int) { | |
/** Duration of the session, between the first and last events + session gap */ | |
def durationMs: Long = sessionEndTimestampMs - sessionStartTimestampMs | |
} | |
case class SessionUpdate(id: String, | |
sessionStartTimestampSecs: Long, | |
sessionEndTimestampSecs: Long, | |
durationSecs: Long, | |
numEvents: Int) | |
test("session window - flatMapGroupsWithState") { | |
import java.sql.Timestamp | |
val inputData = MemoryStream[(String, Long)] | |
val events = inputData.toDF() | |
.select($"_1".as("value"), $"_2".as("timestamp")) | |
.as[(String, Timestamp)] | |
.flatMap { case (v, timestamp) => | |
v.split(" ").map { word => (word, timestamp) } | |
} | |
.as[(String, Timestamp)] | |
.withWatermark("_2", "30 seconds") | |
val outputMode = OutputMode.Append() // below stateFunc also supports OutputMode.Update | |
val sessionGapMills = 10 * 1000 | |
val stateFunc: (String, Iterator[(String, Timestamp)], GroupState[List[SessionInfo]]) | |
=> Iterator[SessionUpdate] = | |
(sessionId: String, events: Iterator[(String, Timestamp)], | |
state: GroupState[List[SessionInfo]]) => { | |
def handleEvict(sessionId: String, state: GroupState[List[SessionInfo]]) | |
: Iterator[SessionUpdate] = { | |
state.getOption match { | |
case Some(lst) => | |
// assuming sessions are sorted by session start timestamp | |
val (evicted, kept) = lst.span { | |
s => s.sessionEndTimestampMs < state.getCurrentWatermarkMs() | |
} | |
if (kept.isEmpty) { | |
state.remove() | |
} else { | |
state.update(kept) | |
state.setTimeoutTimestamp(kept.head.sessionEndTimestampMs) | |
} | |
outputMode match { | |
case s if s == OutputMode.Append() => | |
evicted.iterator.map(si => SessionUpdate(sessionId, | |
si.sessionStartTimestampMs / 1000, | |
si.sessionEndTimestampMs / 1000, | |
si.durationMs / 1000, si.numEvents)) | |
case s if s == OutputMode.Update() => Seq.empty[SessionUpdate].iterator | |
case s => throw new UnsupportedOperationException(s"Not supported output mode $s") | |
} | |
case None => | |
state.remove() | |
Seq.empty[SessionUpdate].iterator | |
} | |
} | |
def mergeSession(session1: SessionInfo, session2: SessionInfo): SessionInfo = { | |
SessionInfo( | |
sessionStartTimestampMs = Math.min(session1.sessionStartTimestampMs, | |
session2.sessionStartTimestampMs), | |
sessionEndTimestampMs = Math.max(session1.sessionEndTimestampMs, | |
session2.sessionEndTimestampMs), | |
numEvents = session1.numEvents + session2.numEvents) | |
} | |
def handleEvents(sessionId: String, events: Iterator[(String, Timestamp)], | |
state: GroupState[List[SessionInfo]]): Iterator[SessionUpdate] = { | |
import java.{util => ju} | |
import scala.collection.mutable | |
import collection.JavaConverters._ | |
// we assume only previous sessions are sorted: events are not guaranteed to be sorted. | |
// we also assume the number of sessions for each key is not huge, which is valid | |
// unless end users set huge watermark delay as well as smaller session gap. | |
val newSessions: ju.LinkedList[SessionInfo] = state.getOption match { | |
case Some(lst) => new ju.LinkedList[SessionInfo](lst.asJava) | |
case None => new ju.LinkedList[SessionInfo]() | |
} | |
// this is to track the change of sessions for update mode | |
// if you define "update" as returning whole new sessions on given key, | |
// you can remove this and logic to track sessions | |
val updatedSessions = new mutable.ListBuffer[SessionInfo]() | |
while (events.hasNext) { | |
val ev = events.next() | |
// convert each event to one of session window | |
val event = SessionInfo(ev._2.getTime, ev._2.getTime + sessionGapMills, 1) | |
// find matched session | |
var index = 0 | |
var updated = false | |
while (!updated && index < newSessions.size()) { | |
val session = newSessions.get(index) | |
if (event.sessionEndTimestampMs < session.sessionStartTimestampMs) { | |
// no matched session, and following sessions will not be matched | |
newSessions.add(index, event) | |
updated = true | |
updatedSessions += event | |
} else if (event.sessionStartTimestampMs > session.sessionEndTimestampMs) { | |
// continue to next session | |
index += 1 | |
} else { | |
// matched: update session | |
var newSession = session.copy( | |
sessionStartTimestampMs = Math.min(session.sessionStartTimestampMs, | |
event.sessionStartTimestampMs), | |
sessionEndTimestampMs = Math.max(session.sessionEndTimestampMs, | |
event.sessionEndTimestampMs), | |
numEvents = session.numEvents + event.numEvents) | |
// we are going to replace previous session with new session, so previous session should be removed from updated sessions | |
// same occurs below if statements | |
updatedSessions -= session | |
// check for a chance to concatenate new session and next session | |
if (index + 1 < newSessions.size()) { | |
val nextSession = newSessions.get(index + 1) | |
if (newSession.sessionEndTimestampMs <= nextSession.sessionStartTimestampMs) { | |
newSession = mergeSession(newSession, nextSession) | |
updatedSessions -= nextSession | |
newSessions.remove(index + 1) | |
} | |
} | |
// check for a chance to concatenate new session and previous session | |
if (index - 1 >= 0) { | |
val prevSession = newSessions.get(index - 1) | |
if (newSession.sessionEndTimestampMs <= prevSession.sessionStartTimestampMs) { | |
newSession = mergeSession(newSession, prevSession) | |
updatedSessions -= prevSession | |
newSessions.remove(index - 1) | |
index -= 1 | |
} | |
} | |
newSessions.set(index, newSession) | |
updatedSessions += newSession | |
updated = true | |
} | |
} | |
if (!updated) { | |
// none matched so far, add to last | |
newSessions.addLast(event) | |
updatedSessions += event | |
} | |
} | |
val newSessionsForScala = newSessions.asScala.toList | |
state.update(newSessionsForScala) | |
// there must be at least one session available | |
// set timeout to earliest sessions' session end: we will traverse and evict sessions | |
state.setTimeoutTimestamp(newSessionsForScala.head.sessionEndTimestampMs) | |
outputMode match { | |
case s if s == OutputMode.Update() => | |
updatedSessions.iterator.map(si => | |
SessionUpdate(sessionId, si.sessionStartTimestampMs / 1000, | |
si.sessionEndTimestampMs / 1000, si.durationMs / 1000, si.numEvents)) | |
case s if s == OutputMode.Append() => Seq.empty[SessionUpdate].iterator | |
case s => throw new UnsupportedOperationException(s"Not supported output mode $s") | |
} | |
} | |
if (state.hasTimedOut) { | |
handleEvict(sessionId, state) | |
} else { | |
handleEvents(sessionId, events, state) | |
} | |
} | |
val sessionUpdates = events | |
.groupByKey(event => event._1) | |
.flatMapGroupsWithState[List[SessionInfo], SessionUpdate]( | |
outputMode, timeoutConf = GroupStateTimeout.EventTimeTimeout())(stateFunc) | |
// codes for verifying output place here | |
} | |
// below test code is providing same result as above | |
test("session window - session_window (SPARK-10816)") { | |
val inputData = MemoryStream[(String, Long)] | |
// Split the lines into words, treat words as sessionId of events | |
val events = inputData.toDF() | |
.select($"_1".as("value"), $"_2".as("timestamp")) | |
.withColumn("eventTime", $"timestamp".cast("timestamp")) | |
.selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime") | |
.withWatermark("eventTime", "10 seconds") | |
val sessionUpdates = events | |
.groupBy(session_window($"eventTime", "10 seconds") as 'session, 'sessionId) | |
.agg(count("*").as("numEvents")) | |
.selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)", | |
"CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs", | |
"numEvents") | |
// codes for verifying output place here | |
} | |
/* | |
// for verifying append mode | |
testStream(sessionUpdates, outputMode)( | |
AddData(inputData, | |
("hello world spark streaming", 40L), | |
("world hello structured streaming", 41L) | |
), | |
// watermark: 11 | |
// current sessions | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("streaming", 40, 51, 11, 2), | |
// ("spark", 40, 50, 10, 1), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
), | |
// placing new sessions "before" previous sessions | |
AddData(inputData, ("spark streaming", 25L)), | |
// watermark: 11 | |
// current sessions | |
// ("spark", 25, 35, 10, 1), | |
// ("streaming", 25, 35, 10, 1), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("streaming", 40, 51, 11, 2), | |
// ("spark", 40, 50, 10, 1), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
), | |
// late event which session's end 10 would be later than watermark 11: should be dropped | |
AddData(inputData, ("spark streaming", 0L)), | |
// watermark: 11 | |
// current sessions | |
// ("spark", 25, 35, 10, 1), | |
// ("streaming", 25, 35, 10, 1), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("streaming", 40, 51, 11, 2), | |
// ("spark", 40, 50, 10, 1), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
), | |
// concatenating multiple previous sessions into one | |
AddData(inputData, ("spark streaming", 30L)), | |
// watermark: 11 | |
// current sessions | |
// ("spark", 25, 50, 25, 3), | |
// ("streaming", 25, 51, 26, 4), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
), | |
// placing new sessions after previous sessions | |
AddData(inputData, ("hello apache spark", 60L)), | |
// watermark: 30 | |
// current sessions | |
// ("spark", 25, 50, 25, 3), | |
// ("streaming", 25, 51, 26, 4), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("structured", 41, 51, 10, 1), | |
// ("hello", 60, 70, 10, 1), | |
// ("apache", 60, 70, 10, 1), | |
// ("spark", 60, 70, 10, 1) | |
CheckNewAnswer( | |
), | |
AddData(inputData, ("structured streaming", 90L)), | |
// watermark: 60 | |
// current sessions | |
// ("hello", 60, 70, 10, 1), | |
// ("apache", 60, 70, 10, 1), | |
// ("spark", 60, 70, 10, 1), | |
// ("structured", 90, 100, 10, 1), | |
// ("streaming", 90, 100, 10, 1) | |
CheckNewAnswer( | |
("spark", 25, 50, 25, 3), | |
("streaming", 25, 51, 26, 4), | |
("hello", 40, 51, 11, 2), | |
("world", 40, 51, 11, 2), | |
("structured", 41, 51, 10, 1) | |
) | |
) | |
*/ | |
/* | |
// for verifying update mode | |
testStream(sessionUpdates, outputMode)( | |
AddData(inputData, | |
("hello world spark streaming", 40L), | |
("world hello structured streaming", 41L) | |
), | |
// watermark: 11 | |
// current sessions | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("streaming", 40, 51, 11, 2), | |
// ("spark", 40, 50, 10, 1), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
("hello", 40, 51, 11, 2), | |
("world", 40, 51, 11, 2), | |
("streaming", 40, 51, 11, 2), | |
("spark", 40, 50, 10, 1), | |
("structured", 41, 51, 10, 1) | |
), | |
// placing new sessions "before" previous sessions | |
AddData(inputData, ("spark streaming", 25L)), | |
// watermark: 11 | |
// current sessions | |
// ("spark", 25, 35, 10, 1), | |
// ("streaming", 25, 35, 10, 1), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("streaming", 40, 51, 11, 2), | |
// ("spark", 40, 50, 10, 1), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
("spark", 25, 35, 10, 1), | |
("streaming", 25, 35, 10, 1) | |
), | |
// late event which session's end 10 would be later than watermark 11: should be dropped | |
AddData(inputData, ("spark streaming", 0L)), | |
// watermark: 11 | |
// current sessions | |
// ("spark", 25, 35, 10, 1), | |
// ("streaming", 25, 35, 10, 1), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("streaming", 40, 51, 11, 2), | |
// ("spark", 40, 50, 10, 1), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
), | |
// concatenating multiple previous sessions into one | |
AddData(inputData, ("spark streaming", 30L)), | |
// watermark: 11 | |
// current sessions | |
// ("spark", 25, 50, 25, 3), | |
// ("streaming", 25, 51, 26, 4), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
("spark", 25, 50, 25, 3), | |
("streaming", 25, 51, 26, 4) | |
), | |
// placing new sessions after previous sessions | |
AddData(inputData, ("hello apache spark", 60L)), | |
// watermark: 30 | |
// current sessions | |
// ("spark", 25, 50, 25, 3), | |
// ("streaming", 25, 51, 26, 4), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("structured", 41, 51, 10, 1), | |
// ("hello", 60, 70, 10, 1), | |
// ("apache", 60, 70, 10, 1), | |
// ("spark", 60, 70, 10, 1) | |
CheckNewAnswer( | |
("hello", 60, 70, 10, 1), | |
("apache", 60, 70, 10, 1), | |
("spark", 60, 70, 10, 1) | |
), | |
AddData(inputData, ("structured streaming", 90L)), | |
// watermark: 60 | |
// current sessions | |
// ("hello", 60, 70, 10, 1), | |
// ("apache", 60, 70, 10, 1), | |
// ("spark", 60, 70, 10, 1), | |
// ("structured", 90, 100, 10, 1), | |
// ("streaming", 90, 100, 10, 1) | |
// evicted | |
// ("spark", 25, 50, 25, 3), | |
// ("streaming", 25, 51, 26, 4), | |
// ("hello", 40, 51, 11, 2), | |
// ("world", 40, 51, 11, 2), | |
// ("structured", 41, 51, 10, 1) | |
CheckNewAnswer( | |
("structured", 90, 100, 10, 1), | |
("streaming", 90, 100, 10, 1) | |
) | |
) | |
*/ |
안녕하세요! :)
spark-sql artifact 의 test-jar 를 dependency 로 추가하시면 됩니다.
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
빠른 답변 감사합니다!
즐거운 한가위 보내세요 :)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
안녕하세요, 좋은 코드 감사합니다.
저는 spark에 대해서 배우기 시작한지 얼마 안 된 학생입니다.
다름이 아니라, 주석처리된 223번째 줄의 testStream의 경우 spark의 StreamTest에서 나온 것이 맞나요?
어떻게 가져다 쓰셨는지가 궁금해서 여쭤봅니다.
저의 경우에는
import org.apache.spark.sql.streaming.StreamTest
class StreamTestClass extends StreamTest
라고 해줘도 import 할 수가 없었습니다.
부족하지만 작은 도움이나마 주시면 매우 감사드리겠습니다.