This is an automated email from the ASF dual-hosted git repository.
maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6b5917beff30 [SPARK-46961][SS] Using ProcessorContext to store and retrieve handle
6b5917beff30 is described below
commit 6b5917beff30c813a362584a135a587001df1390
Author: Eric Marnadi <[email protected]>
AuthorDate: Mon Mar 4 21:20:23 2024 +0300
[SPARK-46961][SS] Using ProcessorContext to store and retrieve handle
### What changes were proposed in this pull request?
Setting the processorHandle as a part of the StatefulProcessor, so that the
user doesn't have to explicitly keep track of it and can instead simply call
`getHandle`.
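For illustration, a minimal sketch of the new flow (the processor class, state name, and counting logic here are hypothetical, modeled on the test processors updated below):

```scala
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor, TimerValues, ValueState}

// The handle is no longer an init() parameter: the transformWithState
// operator calls setHandle() before init(), so user code simply calls
// getHandle wherever it needs state access.
class CountProcessor extends StatefulProcessor[String, String, (String, String)] {
  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode): Unit = {
    // getHandle throws STATE_STORE_HANDLE_NOT_INITIALIZED if the processor
    // is used outside of the transformWithState operator.
    countState = getHandle.getValueState[Long]("countState")
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, String)] = {
    // Illustrative running count per grouping key.
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    countState.update(count)
    Iterator.single((key, count.toString))
  }
}
```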
### Why are the changes needed?
This enhances the usability of the State API
### Does this PR introduce _any_ user-facing change?
Yes, this is an API change. It enhances the usability of the
StatefulProcessorHandle and the TransformWithState operator.
### How was this patch tested?
Existing unit tests are sufficient
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45359 from ericm-db/handle-context.
Authored-by: Eric Marnadi <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
---
 .../src/main/resources/error/error-classes.json    |  7 +++
 docs/sql-error-conditions.md                       |  7 +++
 .../apache/spark/sql/errors/ExecutionErrors.scala  |  6 +++
 .../spark/sql/streaming/StatefulProcessor.scala    | 38 ++++++++++++---
 .../streaming/TransformWithStateExec.scala         |  4 +-
 .../streaming/TransformWithListStateSuite.scala    | 14 ++----
 .../sql/streaming/TransformWithStateSuite.scala    | 54 ++++++++++------------
 7 files changed, 84 insertions(+), 46 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 6ccd841ccd0f..7cf3e9c533ca 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3337,6 +3337,13 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_HANDLE_NOT_INITIALIZED" : {
+    "message" : [
+      "The handle has not been initialized for this StatefulProcessor.",
+      "Please only use the StatefulProcessor within the transformWithState operator."
+    ],
+    "sqlState" : "42802"
+  },
   "STATE_STORE_MULTIPLE_VALUES_PER_KEY" : {
     "message" : [
       "Store does not support multiple values per key"
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index f026c456eb2d..7be01f8cb513 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2091,6 +2091,13 @@ Star (*) is not allowed in a select list when GROUP BY an ordinal position is us
Failed to remove default column family with reserved name=`<colFamilyName>`.
+### STATE_STORE_HANDLE_NOT_INITIALIZED
+
+[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+The handle has not been initialized for this StatefulProcessor.
+Please only use the StatefulProcessor within the transformWithState operator.
+
### STATE_STORE_MULTIPLE_VALUES_PER_KEY
[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index b74a67b49bda..7910c386fcf1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -53,6 +53,12 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase {
       e)
   }
 
+  def stateStoreHandleNotInitialized(): SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED",
+      messageParameters = Map.empty)
+  }
+
   def failToRecognizePatternAfterUpgradeError(
       pattern: String, e: Throwable): SparkUpgradeException = {
     new SparkUpgradeException(
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 76794136dd49..42a9430bf39d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
import java.io.Serializable
import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.sql.errors.ExecutionErrors
 /**
  * Represents the arbitrary stateful logic that needs to be provided by the user to perform
@@ -29,17 +30,18 @@ import org.apache.spark.annotation.{Evolving, Experimental}
@Evolving
private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
+  /**
+   * Handle to the stateful processor that provides access to the state store and other
+   * stateful processing related APIs.
+   */
+  private var statefulProcessorHandle: StatefulProcessorHandle = null
+
   /**
    * Function that will be invoked as the first method that allows for users to
    * initialize all their state variables and perform other init actions before handling data.
-   * @param handle - reference to the statefulProcessorHandle that the user can use to perform
-   *                 actions like creating state variables, accessing queryInfo etc. Please refer to
-   *                 [[StatefulProcessorHandle]] for more details.
    * @param outputMode - output mode for the stateful processor
    */
-  def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode): Unit
+  def init(outputMode: OutputMode): Unit
   /**
    * Function that will allow users to interact with input data rows along with the grouping key
@@ -59,5 +61,27 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
    * Function called as the last method that allows for users to perform
    * any cleanup or teardown operations.
    */
-  def close (): Unit
+  def close (): Unit = {}
+
+  /**
+   * Function to set the stateful processor handle that will be used to interact with the state
+   * store and other stateful processor related operations.
+   *
+   * @param handle - instance of StatefulProcessorHandle
+   */
+  final def setHandle(handle: StatefulProcessorHandle): Unit = {
+    statefulProcessorHandle = handle
+  }
+
+  /**
+   * Function to get the stateful processor handle that will be used to interact with the state store.
+   *
+   * @return handle - instance of StatefulProcessorHandle
+   */
+  final def getHandle: StatefulProcessorHandle = {
+    if (statefulProcessorHandle == null) {
+      throw ExecutionErrors.stateStoreHandleNotInitialized()
+    }
+    statefulProcessorHandle
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 5a80fb1209ba..117bc722f09e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -156,6 +156,7 @@ case class TransformWithStateExec(
       setStoreMetrics(store)
       setOperatorMetrics()
       statefulProcessor.close()
+      statefulProcessor.setHandle(null)
       processorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
     })
}
@@ -228,7 +229,8 @@ case class TransformWithStateExec(
     val processorHandle = new StatefulProcessorHandleImpl(
       store, getStateInfo.queryRunId, keyEncoder, isStreaming)
     assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
-    statefulProcessor.init(processorHandle, outputMode)
+    statefulProcessor.setHandle(processorHandle)
+    statefulProcessor.init(outputMode)
     processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
     processDataWithPartition(singleIterator, store, processorHandle)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index f7ed813badde..3d085da4ab58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -27,12 +27,10 @@ case class InputRow(key: String, action: String, value: String)
 class TestListStateProcessor
   extends StatefulProcessor[String, InputRow, (String, String)] {
 
-  @transient var _processorHandle: StatefulProcessorHandle = _
   @transient var _listState: ListState[String] = _
 
-  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
-    _processorHandle = handle
-    _listState = handle.getListState("testListState")
+  override def init(outputMode: OutputMode): Unit = {
+    _listState = getHandle.getListState("testListState")
   }
override def handleInputRows(
@@ -84,14 +82,12 @@ class TestListStateProcessor
 class ToggleSaveAndEmitProcessor
   extends StatefulProcessor[String, String, String] {
 
-  @transient var _processorHandle: StatefulProcessorHandle = _
   @transient var _listState: ListState[String] = _
   @transient var _valueState: ValueState[Boolean] = _
 
-  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
-    _processorHandle = handle
-    _listState = handle.getListState("testListState")
-    _valueState = handle.getValueState("testValueState")
+  override def init(outputMode: OutputMode): Unit = {
+    _listState = getHandle.getListState("testListState")
+    _valueState = getHandle.getValueState("testValueState")
   }
override def handleInputRows(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index a4a04e0b5077..8a87472a023a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, SparkRuntimeException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
   RocksDBStateStoreProvider, StateStoreMultipleColumnFamiliesNotSupportedException}
@@ -30,14 +30,9 @@ object TransformWithStateSuiteUtils {
 class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)]
   with Logging {
   @transient private var _countState: ValueState[Long] = _
-  @transient var _processorHandle: StatefulProcessorHandle = _
-
-  override def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode) : Unit = {
-    _processorHandle = handle
-    assert(handle.getQueryInfo().getBatchId >= 0)
-    _countState = _processorHandle.getValueState[Long]("countState")
+
+  override def init(outputMode: OutputMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState")
   }
override def handleInputRows(
@@ -62,17 +57,11 @@ class RunningCountMostRecentStatefulProcessor
with Logging {
   @transient private var _countState: ValueState[Long] = _
   @transient private var _mostRecent: ValueState[String] = _
-  @transient var _processorHandle: StatefulProcessorHandle = _
-
-  override def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode) : Unit = {
-    _processorHandle = handle
-    assert(handle.getQueryInfo().getBatchId >= 0)
-    _countState = _processorHandle.getValueState[Long]("countState")
-    _mostRecent = _processorHandle.getValueState[String]("mostRecent")
-  }
+  override def init(outputMode: OutputMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState")
+    _mostRecent = getHandle.getValueState[String]("mostRecent")
+  }
 
override def handleInputRows(
key: String,
inputRows: Iterator[(String, String)],
@@ -96,15 +85,10 @@ class MostRecentStatefulProcessorWithDeletion
   extends StatefulProcessor[String, (String, String), (String, String)]
   with Logging {
   @transient private var _mostRecent: ValueState[String] = _
-  @transient var _processorHandle: StatefulProcessorHandle = _
-
-  override def init(
-      handle: StatefulProcessorHandle,
-      outputMode: OutputMode) : Unit = {
-    _processorHandle = handle
-    assert(handle.getQueryInfo().getBatchId >= 0)
-    _processorHandle.deleteIfExists("countState")
-    _mostRecent = _processorHandle.getValueState[String]("mostRecent")
+
+  override def init(outputMode: OutputMode): Unit = {
+    getHandle.deleteIfExists("countState")
+    _mostRecent = getHandle.getValueState[String]("mostRecent")
   }
override def handleInputRows(
@@ -132,7 +116,7 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
       inputRows: Iterator[String],
       timerValues: TimerValues): Iterator[(String, String)] = {
     // Trying to create value state here should fail
-    _tempState = _processorHandle.getValueState[Long]("tempState")
+    _tempState = getHandle.getValueState[Long]("tempState")
     Iterator.empty
   }
 }
@@ -195,6 +179,18 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
}
+  test("Use statefulProcessor without transformWithState - handle should be absent") {
+    val processor = new RunningCountStatefulProcessor()
+    val ex = intercept[Exception] {
+      processor.getHandle
+    }
+    checkError(
+      ex.asInstanceOf[SparkRuntimeException],
+      errorClass = "STATE_STORE_HANDLE_NOT_INITIALIZED",
+      parameters = Map.empty
+    )
+  }
+
   test("transformWithState - batch should succeed") {
     val inputData = Seq("a", "b")
     val result = inputData.toDS()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]