This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 88b29c5076d4 [SPARK-47570][SS] Integrate range scan encoder changes with timer implementation
88b29c5076d4 is described below
commit 88b29c5076d48f4ecbed402a693a8ccce57cd7d0
Author: jingz-db <[email protected]>
AuthorDate: Wed Mar 27 13:37:48 2024 +0900
[SPARK-47570][SS] Integrate range scan encoder changes with timer implementation
### What changes were proposed in this pull request?
Previously, the timer state implementation used the no-prefix RocksDB state encoder, so the iterator returned by `iterator()` or `prefix()` was not sorted on the timestamp value. After Anish's PR added support for the range scan encoder, we can integrate it with `TimerStateImpl` so that the range scan encoder is used on the `timer to key` secondary index.
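As a rough sketch of why this helps (plain Scala; `getExpiredTimers` here is an illustrative stand-in, not the actual state store API), a timestamp-sorted index lets expired-timer retrieval terminate at the first non-expired entry instead of scanning and filtering the full column family:

```scala
// Minimal sketch, assuming the secondary index iterator yields
// (groupingKey, timestampMs) pairs in ascending timestamp order,
// which is what the range scan encoder provides.
object RangeScanTimerSketch {
  def getExpiredTimers(
      sortedIndex: Iterator[(String, Long)],
      expiryTimestampMs: Long): Iterator[(String, Long)] = {
    // Entries arrive sorted by timestamp, so everything at or past the
    // threshold is unexpired and iteration can stop early.
    sortedIndex.takeWhile { case (_, ts) => ts < expiryTimestampMs }
  }

  def main(args: Array[String]): Unit = {
    val index = Iterator(("k1", 1000L), ("k1", 2000L), ("k2", 15000L))
    // Prints (k1,1000) and (k1,2000); the 15000L entry is never touched.
    getExpiredTimers(index, 10000L).foreach(println)
  }
}
```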
### Why are the changes needed?
These changes are part of the work to add a new stateful streaming operator for arbitrary state management, providing the new features listed in the SPIP JIRA: https://issues.apache.org/jira/browse/SPARK-45939
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests in `TimerSuite`
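The essential property those tests pin down can be sketched independently of the store (hypothetical values; plain collections stand in for RocksDB):

```scala
// Sketch of the ordering contract the new TimerSuite cases assert:
// timestamps registered in arbitrary order come back sorted ascending,
// and only entries strictly below the expiry threshold are returned.
object TimerOrderingSketch {
  def main(args: Array[String]): Unit = {
    val registered = Seq(931L, 8000L, 452300L, 4200L, 90L)
    val expired = registered.sorted.takeWhile(_ < 4200L)
    assert(expired == Seq(90L, 931L)) // 4200L itself is not expired
  }
}
```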
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45709 from jingz-db/integrate-range-scan.
Authored-by: jingz-db <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../streaming/StatefulProcessorHandleImpl.scala | 8 ++-
.../sql/execution/streaming/TimerStateImpl.scala | 19 ++++--
.../streaming/TransformWithStateExec.scala | 16 ++---
.../sql/execution/streaming/state/TimerSuite.scala | 69 +++++++++++++++++++---
4 files changed, 85 insertions(+), 27 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 9b905ad5235d..5f3b794fd117 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -163,12 +163,14 @@ class StatefulProcessorHandleImpl(
}
/**
- * Function to retrieve all registered timers for all grouping keys
+ * Function to retrieve all expired registered timers for all grouping keys
+ * @param expiryTimestampMs Threshold for expired timestamp in milliseconds; this function
+ *   will return all timers that have a timestamp less than the passed threshold
* @return - iterator of registered timers for all grouping keys
*/
- def getExpiredTimers(): Iterator[(Any, Long)] = {
+ def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
verifyTimerOperations("get_expired_timers")
- timerState.getExpiredTimers()
+ timerState.getExpiredTimers(expiryTimestampMs)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 6166374d25e9..af321eecb4db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -91,7 +91,7 @@ class TimerStateImpl(
val tsToKeyCFName = timerCFName + TimerStateUtils.TIMESTAMP_TO_KEY_CF
store.createColFamilyIfAbsent(tsToKeyCFName, keySchemaForSecIndex,
- schemaForValueRow, NoPrefixKeyStateEncoderSpec(keySchemaForSecIndex),
+ schemaForValueRow, RangeKeyScanStateEncoderSpec(keySchemaForSecIndex, 1),
useMultipleValuesPerKey = false, isInternal = true)
private def getGroupingKey(cfName: String): Any = {
@@ -110,7 +110,6 @@ class TimerStateImpl(
// We maintain a secondary index that inverts the ordering of the timestamp
// and grouping key
- // TODO: use range scan encoder to encode the secondary index key
private def encodeSecIndexKey(groupingKey: Any, expiryTimestampMs: Long): UnsafeRow = {
val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
val keyRow = secIndexKeyEncoder(InternalRow(expiryTimestampMs, keyByteArr))
@@ -187,10 +186,15 @@ class TimerStateImpl(
}
/**
- * Function to get all the registered timers for all grouping keys
+ * Function to get all the expired registered timers for all grouping keys.
+ * Performs a range scan on timestamp and stops iterating once the key row timestamp equals or
+ *   exceeds the limit (the timestamp key is sorted in increasing order).
+ * @param expiryTimestampMs Threshold for expired timestamp in milliseconds; this function
+ *   will return all timers that have a timestamp less than the passed threshold.
* @return - iterator of all the registered timers for all grouping keys
*/
- def getExpiredTimers(): Iterator[(Any, Long)] = {
+ def getExpiredTimers(expiryTimestampMs: Long): Iterator[(Any, Long)] = {
+ // this iterator is sorted in increasing order of timestamp
val iter = store.iterator(tsToKeyCFName)
new NextIterator[(Any, Long)] {
@@ -199,7 +203,12 @@ class TimerStateImpl(
val rowPair = iter.next()
val keyRow = rowPair.key
val result = getTimerRowFromSecIndex(keyRow)
- result
+ if (result._2 < expiryTimestampMs) {
+ result
+ } else {
+ finished = true
+ null.asInstanceOf[(Any, Long)]
+ }
} else {
finished = true
null.asInstanceOf[(Any, Long)]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 39365e92185a..d3640ebd8850 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -160,26 +160,18 @@ case class TransformWithStateExec(
case ProcessingTime =>
assert(batchTimestampMs.isDefined)
val batchTimestamp = batchTimestampMs.get
- val procTimeIter = processorHandle.getExpiredTimers()
- procTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
- if (expiryTimestampMs < batchTimestamp) {
+ processorHandle.getExpiredTimers(batchTimestamp)
+ .flatMap { case (keyObj, expiryTimestampMs) =>
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
- } else {
- Iterator.empty
}
- }
case EventTime =>
assert(eventTimeWatermarkForEviction.isDefined)
val watermark = eventTimeWatermarkForEviction.get
- val eventTimeIter = processorHandle.getExpiredTimers()
- eventTimeIter.flatMap { case (keyObj, expiryTimestampMs) =>
- if (expiryTimestampMs < watermark) {
+ processorHandle.getExpiredTimers(watermark)
+ .flatMap { case (keyObj, expiryTimestampMs) =>
handleTimerRows(keyObj, expiryTimestampMs, processorHandle)
- } else {
- Iterator.empty
}
- }
case _ => Iterator.empty
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
index 1aae0e0498aa..1af33aa7b5ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/TimerSuite.scala
@@ -48,7 +48,8 @@ class TimerSuite extends StateVariableSuiteBase {
Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
timerState.registerTimer(1L * 1000)
assert(timerState.listTimers().toSet === Set(1000L))
- assert(timerState.getExpiredTimers().toSet === Set(("test_key", 1000L)))
+ assert(timerState.getExpiredTimers(Long.MaxValue).toSeq === Seq(("test_key", 1000L)))
+ assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty[Long])
timerState.registerTimer(20L * 1000)
assert(timerState.listTimers().toSet === Set(20000L, 1000L))
@@ -69,8 +70,10 @@ class TimerSuite extends StateVariableSuiteBase {
timerState1.registerTimer(1L * 1000)
timerState2.registerTimer(15L * 1000)
assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
- assert(timerState1.getExpiredTimers().toSet ===
- Set(("test_key", 15000L), ("test_key", 1000L)))
+ assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
+ Seq(("test_key", 1000L), ("test_key", 15000L)))
+ // if a timestamp equals expiryTimestampMs, it is not considered expired
+ assert(timerState1.getExpiredTimers(15000L).toSeq === Seq(("test_key", 1000L)))
assert(timerState1.listTimers().toSet === Set(15000L, 1000L))
timerState1.registerTimer(20L * 1000)
@@ -99,15 +102,67 @@ class TimerSuite extends StateVariableSuiteBase {
ImplicitGroupingKeyTracker.removeImplicitKey()
ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
- assert(timerState1.getExpiredTimers().toSet ===
- Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+ assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq ===
+ Seq(("test_key1", 1000L), ("test_key1", 2000L), ("test_key2", 15000L)))
+ assert(timerState1.getExpiredTimers(10000L).toSeq ===
+ Seq(("test_key1", 1000L), ("test_key1", 2000L)))
assert(timerState1.listTimers().toSet === Set(1000L, 2000L))
ImplicitGroupingKeyTracker.removeImplicitKey()
ImplicitGroupingKeyTracker.setImplicitKey("test_key2")
assert(timerState2.listTimers().toSet === Set(15000L))
- assert(timerState2.getExpiredTimers().toSet ===
- Set(("test_key2", 15000L), ("test_key1", 2000L), ("test_key1", 1000L)))
+ assert(timerState2.getExpiredTimers(1500L).toSeq === Seq(("test_key1", 1000L)))
+ }
+ }
+
+ testWithTimeOutMode("Range scan on second index timer key - " +
+ "verify timestamp is sorted for single instance") { timeoutMode =>
+ tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+ val store = provider.getStore(0)
+
+ ImplicitGroupingKeyTracker.setImplicitKey("test_key")
+ val timerState = new TimerStateImpl(store, timeoutMode,
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ val timerTimestamps = Seq(931L, 8000L, 452300L, 4200L, 90L, 1L, 2L, 8L, 3L, 35L, 6L, 9L, 5L)
+ // register/put unordered timestamps into RocksDB
+ timerTimestamps.foreach(timerState.registerTimer)
+ assert(timerState.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) === timerTimestamps.sorted)
+ assert(timerState.getExpiredTimers(4200L).toSeq.map(_._2) ===
+   timerTimestamps.sorted.takeWhile(_ < 4200L))
+ assert(timerState.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
+ ImplicitGroupingKeyTracker.removeImplicitKey()
+ }
+ }
+
+ testWithTimeOutMode("test range scan on second index timer key - " +
+ "verify timestamp is sorted for multiple instances") { timeoutMode =>
+ tryWithProviderResource(newStoreProviderWithStateVariable(true)) { provider =>
+ val store = provider.getStore(0)
+
+ ImplicitGroupingKeyTracker.setImplicitKey("test_key1")
+ val timerState1 = new TimerStateImpl(store, timeoutMode,
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ val timerTimestamps1 = Seq(64L, 32L, 1024L, 4096L, 0L, 1L)
+ timerTimestamps1.foreach(timerState1.registerTimer)
+
+ val timerState2 = new TimerStateImpl(store, timeoutMode,
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ val timerTimestamps2 = Seq(931L, 8000L, 452300L, 4200L)
+ timerTimestamps2.foreach(timerState2.registerTimer)
+ ImplicitGroupingKeyTracker.removeImplicitKey()
+
+ ImplicitGroupingKeyTracker.setImplicitKey("test_key3")
+ val timerState3 = new TimerStateImpl(store, timeoutMode,
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
+ val timerTimestamps3 = Seq(1L, 2L, 8L, 3L)
+ timerTimestamps3.foreach(timerState3.registerTimer)
+ ImplicitGroupingKeyTracker.removeImplicitKey()
+
+ assert(timerState1.getExpiredTimers(Long.MaxValue).toSeq.map(_._2) ===
+   (timerTimestamps1 ++ timerTimestamps2 ++ timerTimestamps3).sorted)
+ assert(timerState1.getExpiredTimers(Long.MinValue).toSeq === Seq.empty)
+ assert(timerState1.getExpiredTimers(8000L).toSeq.map(_._2) ===
+   (timerTimestamps1 ++ timerTimestamps2 ++ timerTimestamps3).sorted.takeWhile(_ < 8000L))
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]