This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e610d1d8f79b [SPARK-46852][SS] Remove use of explicit key encoder and
pass it implicitly to the operator for transformWithState operator
e610d1d8f79b is described below
commit e610d1d8f79b913cb9ee9236a6325202c58d8397
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Thu Feb 1 22:31:07 2024 +0900
[SPARK-46852][SS] Remove use of explicit key encoder and pass it implicitly
to the operator for transformWithState operator
### What changes were proposed in this pull request?
Remove the use of an explicit key encoder and instead pass it implicitly
to the transformWithState operator.
### Why are the changes needed?
The changes are needed to avoid asking users to provide an explicit key
encoder; we might also need them for subsequent timer-related changes.
### Does this PR introduce _any_ user-facing change?
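Yes. `StatefulProcessorHandle.getValueState` no longer takes a key encoder argument or a key type parameter.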
### How was this patch tested?
Existing unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44974 from anishshri-db/task/SPARK-46852.
Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../sql/streaming/StatefulProcessorHandle.scala | 5 +----
.../spark/sql/catalyst/plans/logical/object.scala | 3 +++
.../spark/sql/execution/SparkStrategies.scala | 3 ++-
.../streaming/StatefulProcessorHandleImpl.scala | 13 +++++++++----
.../streaming/TransformWithStateExec.scala | 6 +++++-
.../sql/execution/streaming/ValueStateImpl.scala | 12 +++++-------
.../streaming/state/ValueStateSuite.scala | 22 +++++++++++-----------
.../sql/streaming/TransformWithStateSuite.scala | 8 +++-----
8 files changed, 39 insertions(+), 33 deletions(-)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index 302de4a3c947..5eaccceb947c 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
import java.io.Serializable
import org.apache.spark.annotation.{Evolving, Experimental}
-import org.apache.spark.sql.Encoder
/**
* Represents the operation handle provided to the stateful processor used in
the
@@ -34,12 +33,10 @@ private[sql] trait StatefulProcessorHandle extends
Serializable {
* The user must ensure to call this function only within the `init()`
method of the
* StatefulProcessor.
* @param stateName - name of the state variable
- * @param keyEncoder - Spark SQL Encoder for key
- * @tparam K - type of key
* @tparam T - type of state variable
* @return - instance of ValueState of type T that can be used to store
state persistently
*/
- def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]):
ValueState[T]
+ def getValueState[T](stateName: String): ValueState[T]
/** Function to return queryInfo for currently running task */
def getQueryInfo(): QueryInfo
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 8f937dd5a777..cb8673d20ed3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -577,6 +577,7 @@ object TransformWithState {
timeoutMode: TimeoutMode,
outputMode: OutputMode,
child: LogicalPlan): LogicalPlan = {
+ val keyEncoder = encoderFor[K]
val mapped = new TransformWithState(
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
@@ -585,6 +586,7 @@ object TransformWithState {
statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
timeoutMode,
outputMode,
+ keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
CatalystSerde.generateObjAttr[U],
child
)
@@ -600,6 +602,7 @@ case class TransformWithState(
statefulProcessor: StatefulProcessor[Any, Any, Any],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
+ keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectProducer {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5d4063d125c8..f5c2f17f8826 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -728,7 +728,7 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
case TransformWithState(
keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode,
- outputAttr, child) =>
+ keyEncoder, outputAttr, child) =>
val execPlan = TransformWithStateExec(
keyDeserializer,
valueDeserializer,
@@ -737,6 +737,7 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
statefulProcessor,
timeoutMode,
outputMode,
+ keyEncoder,
outputAttr,
stateInfo = None,
batchTimestampMs = None,
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 758e8c646ffc..d0cd8f7dc0a3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -20,7 +20,7 @@ import java.util.UUID
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.{QueryInfo, StatefulProcessorHandle,
ValueState}
import org.apache.spark.util.Utils
@@ -67,8 +67,13 @@ class QueryInfoImpl(
* Class that provides a concrete implementation of a StatefulProcessorHandle.
Note that we keep
* track of valid transitions as various functions are invoked to track object
lifecycle.
* @param store - instance of state store
+ * @param runId - unique id for the current run
+ * @param keyEncoder - encoder for the key
*/
-class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
+class StatefulProcessorHandleImpl(
+ store: StateStore,
+ runId: UUID,
+ keyEncoder: ExpressionEncoder[Any])
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._
@@ -108,11 +113,11 @@ class StatefulProcessorHandleImpl(store: StateStore,
runId: UUID)
def getHandleState: StatefulProcessorHandleState = currState
- override def getValueState[K, T](stateName: String, keyEncoder: Encoder[K]):
ValueState[T] = {
+ override def getValueState[T](stateName: String): ValueState[T] = {
verify(currState == CREATED, s"Cannot create state variable with
name=$stateName after " +
"initialization is complete")
store.createColFamilyIfAbsent(stateName)
- val resultState = new ValueStateImpl[K, T](store, stateName, keyEncoder)
+ val resultState = new ValueStateImpl[T](store, stateName, keyEncoder)
resultState
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index ce651d959afc..82e827685b47 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute,
Expression, SortOrder, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution._
@@ -38,6 +39,7 @@ import org.apache.spark.util.CompletionIterator
* @param statefulProcessor processor methods called on underlying data
* @param timeoutMode defines the timeout mode
* @param outputMode defines the output mode for the statefulProcessor
+ * @param keyEncoder expression encoder for the key type
* @param outputObjAttr Defines the output object
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermarkForLateEvents event time watermark for filtering
late events
@@ -52,6 +54,7 @@ case class TransformWithStateExec(
statefulProcessor: StatefulProcessor[Any, Any, Any],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
+ keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
stateInfo: Option[StatefulOperatorStateInfo],
batchTimestampMs: Option[Long],
@@ -162,7 +165,8 @@ case class TransformWithStateExec(
useColumnFamilies = true
) {
case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
- val processorHandle = new StatefulProcessorHandleImpl(store,
getStateInfo.queryRunId)
+ val processorHandle = new StatefulProcessorHandleImpl(store,
getStateInfo.queryRunId,
+ keyEncoder)
assert(processorHandle.getHandleState ==
StatefulProcessorHandleState.CREATED)
statefulProcessor.init(processorHandle, outputMode)
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index 91554de97fe3..5a1b6d01baa3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -21,9 +21,8 @@ import java.io.Serializable
import org.apache.commons.lang3.SerializationUtils
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.ValueState
@@ -38,10 +37,10 @@ import org.apache.spark.sql.types._
* @tparam K - data type of key
* @tparam S - data type of object that will be stored
*/
-class ValueStateImpl[K, S](
+class ValueStateImpl[S](
store: StateStore,
stateName: String,
- keyEnc: Encoder[K]) extends ValueState[S] with Logging {
+ keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {
// TODO: validate places that are trying to encode the key and check if we
can eliminate/
// add caching for some of these calls.
@@ -52,10 +51,9 @@ class ValueStateImpl[K, S](
s"stateName=$stateName")
}
- val exprEnc: ExpressionEncoder[K] = encoderFor(keyEnc)
- val toRow = exprEnc.createSerializer()
+ val toRow = keyExprEnc.createSerializer()
val keyByteArr = toRow
- .apply(keyOption.get.asInstanceOf[K]).asInstanceOf[UnsafeRow].getBytes()
+ .apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 6d929498d65b..49a5fff131ae 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
StatefulProcessorHandleImpl}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.ValueState
@@ -87,10 +88,10 @@ class ValueStateSuite extends SharedSparkSession
test("Implicit key operations") {
tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
val store = provider.getStore(0)
- val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
+ val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
- val testState: ValueState[Long] = handle.getValueState[String,
Long]("testState",
- Encoders.STRING)
+ val testState: ValueState[Long] = handle.getValueState[Long]("testState")
assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
val ex = intercept[Exception] {
testState.update(123)
@@ -118,10 +119,10 @@ class ValueStateSuite extends SharedSparkSession
test("Value state operations for single instance") {
tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
val store = provider.getStore(0)
- val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
+ val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
- val testState: ValueState[Long] = handle.getValueState[String,
Long]("testState",
- Encoders.STRING)
+ val testState: ValueState[Long] = handle.getValueState[Long]("testState")
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(123)
assert(testState.get() === 123)
@@ -144,12 +145,11 @@ class ValueStateSuite extends SharedSparkSession
test("Value state operations for multiple instances") {
tryWithProviderResource(newStoreProviderWithValueState(true)) { provider =>
val store = provider.getStore(0)
- val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID())
+ val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
+ Encoders.STRING.asInstanceOf[ExpressionEncoder[Any]])
- val testState1: ValueState[Long] = handle.getValueState[String,
Long]("testState1",
- Encoders.STRING)
- val testState2: ValueState[Long] = handle.getValueState[String,
Long]("testState2",
- Encoders.STRING)
+ val testState1: ValueState[Long] =
handle.getValueState[Long]("testState1")
+ val testState2: ValueState[Long] =
handle.getValueState[Long]("testState2")
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState1.update(123)
assert(testState1.get() === 123)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 9909919c0cae..70a71f745066 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Encoders, SaveMode}
+import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.execution.streaming._
import
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -38,8 +38,7 @@ class RunningCountStatefulProcessor extends
StatefulProcessor[String, String, (S
outputMode: OutputMode) : Unit = {
_processorHandle = handle
assert(handle.getQueryInfo().getBatchId >= 0)
- _countState = _processorHandle.getValueState[String, Long]("countState",
- Encoders.STRING)
+ _countState = _processorHandle.getValueState[Long]("countState")
}
override def handleInputRows(
@@ -67,8 +66,7 @@ class RunningCountStatefulProcessorWithError extends
RunningCountStatefulProcess
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
// Trying to create value state here should fail
- _tempState = _processorHandle.getValueState[String, Long]("tempState",
- Encoders.STRING)
+ _tempState = _processorHandle.getValueState[Long]("tempState")
Iterator.empty
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]