This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3005dc89084 [SPARK-45558][SS] Introduce a metadata file for streaming stateful operator
3005dc89084 is described below

commit 3005dc8908486f63a3e471cd05189881b833daf1
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Wed Oct 18 15:49:43 2023 +0900

    [SPARK-45558][SS] Introduce a metadata file for streaming stateful operator
    
    ### What changes were proposed in this pull request?
    Introduce a metadata file for streaming stateful operators, and write the metadata of each stateful operator during query planning.
    The information stored in the metadata file:
    - operator id and operator name (the name does not need to be unique among the stateful operators in the query)
    - name of each state store of the operator
    - numColsPrefixKey: > 0 if prefix scan is enabled, 0 otherwise
    - number of partitions of each state store
    The metadata file starts with a version line (e.g. `v1`), followed by the body in JSON format. Like other streaming metadata files, it is versioned to be future-proof.
    
    ### Why are the changes needed?
    The metadata file exposes more information about the state store, which improves debuggability and facilitates the development of state-related features such as reading and writing state and state repartitioning.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests and integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43393 from chaoqin-li1123/state_metadata.
    
    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../spark/sql/execution/QueryExecution.scala       |   2 +-
 .../execution/streaming/IncrementalExecution.scala |  22 ++-
 .../execution/streaming/MicroBatchExecution.scala  |   4 +-
 .../streaming/StreamingSymmetricHashJoinExec.scala |  10 ++
 .../streaming/continuous/ContinuousExecution.scala |   3 +-
 .../streaming/state/OperatorStateMetadata.scala    | 136 ++++++++++++++++
 .../execution/streaming/statefulOperators.scala    |  21 ++-
 .../state/OperatorStateMetadataSuite.scala         | 181 +++++++++++++++++++++
 8 files changed, 374 insertions(+), 5 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b3c97a83970..3d35300773b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -272,7 +272,7 @@ class QueryExecution(
       new IncrementalExecution(
         sparkSession, logical, OutputMode.Append(), "<unknown>",
         UUID.randomUUID, UUID.randomUUID, 0, None, OffsetSeqMetadata(0, 0),
-        WatermarkPropagator.noop())
+        WatermarkPropagator.noop(), false)
     } else {
       this
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index ebdb9caf09e..a67097f6e96 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.atomic.AtomicInteger
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
@@ -32,6 +34,7 @@ import 
org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessi
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
 import 
org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
+import 
org.apache.spark.sql.execution.streaming.state.OperatorStateMetadataWriter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.Utils
@@ -50,7 +53,8 @@ class IncrementalExecution(
     val currentBatchId: Long,
     val prevOffsetSeqMetadata: Option[OffsetSeqMetadata],
     val offsetSeqMetadata: OffsetSeqMetadata,
-    val watermarkPropagator: WatermarkPropagator)
+    val watermarkPropagator: WatermarkPropagator,
+    val isFirstBatch: Boolean)
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
   // Modified planner with stateful operations.
@@ -71,6 +75,8 @@ class IncrementalExecution(
       StreamingGlobalLimitStrategy(outputMode) :: Nil
   }
 
+  // Hadoop configuration used to create the OperatorStateMetadataWriter in
+  // WriteStatefulOperatorMetadataRule.
+  private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
+
+  // Shuffle partition count recorded in the offset log metadata when present, otherwise the
+  // session's current `numShufflePartitions`.
   private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
     .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
     .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
@@ -177,6 +183,17 @@ class IncrementalExecution(
     }
   }
 
+  /**
+   * Writes the state metadata file of every [[StateStoreWriter]] in the plan into
+   * `<checkpointLocation>/<operatorId>`, but only when `isFirstBatch` is set.
+   *
+   * The rule never changes the plan: each matched node is returned as-is, and the rule is
+   * applied only for its side effect.
+   *
+   * NOTE(review): in MicroBatchExecution `isFirstBatch` is `lastExecution == null`, so it is
+   * presumably also true for the first batch planned after a restart; OperatorStateMetadataWriter
+   * skips the write when the metadata file already exists. Confirm this is the intended contract.
+   */
+  object WriteStatefulOperatorMetadataRule extends SparkPlanPartialRule {
+    override val rule: PartialFunction[SparkPlan, SparkPlan] = {
+      case stateStoreWriter: StateStoreWriter if isFirstBatch =>
+        val metadata = stateStoreWriter.operatorStateMetadata()
+        // The operator id doubles as the name of the operator's state checkpoint sub-directory.
+        val metadataWriter = new OperatorStateMetadataWriter(new Path(
+          checkpointLocation, stateStoreWriter.getStateInfo.operatorId.toString), hadoopConf)
+        metadataWriter.write(metadata)
+        stateStoreWriter
+    }
+  }
+
   object StateOpIdRule extends SparkPlanPartialRule {
     override val rule: PartialFunction[SparkPlan, SparkPlan] = {
       case StateStoreSaveExec(keys, None, None, None, None, stateFormatVersion,
@@ -357,6 +374,9 @@ class IncrementalExecution(
 
     override def apply(plan: SparkPlan): SparkPlan = {
       val planWithStateOpId = plan transform composedRule
+      // This transform does not change the plan and its result is discarded: it runs only for
+      // the side effect of writing the metadata file into the checkpoint directory of each
+      // stateful operator.
+      planWithStateOpId transform WriteStatefulOperatorMetadataRule.rule
       simulateWatermarkPropagation(planWithStateOpId)
       planWithStateOpId transform WatermarkPropagationRule.rule
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 8edbfea3eb2..756ee0ca07e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -746,6 +746,7 @@ class MicroBatchExecution(
       StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString)
 
     reportTimeTaken("queryPlanning") {
+      val isFirstBatch = lastExecution == null
       lastExecution = new IncrementalExecution(
         sparkSessionToRunBatch,
         triggerLogicalPlan,
@@ -756,7 +757,8 @@ class MicroBatchExecution(
         currentBatchId,
         offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1),
         offsetSeqMetadata,
-        watermarkPropagator)
+        watermarkPropagator,
+        isFirstBatch)
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 3ad1dc58cae..20a05a10003 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -221,6 +221,16 @@ case class StreamingSymmetricHashJoinExec(
 
   override def shortName: String = "symmetricHashJoin"
 
+  // Names of every state store used by this join, covering both the left and the right side.
+  private val stateStoreNames =
+    SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+
+  /**
+   * Metadata of this join and all of its state stores. Every store is recorded with
+   * `numColsPrefixKey = 0` (no prefix scan) and the operator's number of partitions.
+   */
+  override def operatorStateMetadata(): OperatorStateMetadata = {
+    val info = getStateInfo
+    val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
+    val stateStoreInfo = stateStoreNames.map(StateStoreMetadataV1(_, 0, info.numPartitions)).toArray
+    OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
+  }
+
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
     val watermarkUsedForStateCleanup =
       stateWatermarkPredicates.left.nonEmpty || 
stateWatermarkPredicates.right.nonEmpty
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 58119b74f5a..1fcd9499b8c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -220,7 +220,8 @@ class ContinuousExecution(
         currentBatchId,
         None,
         offsetSeqMetadata,
-        WatermarkPropagator.noop())
+        WatermarkPropagator.noop(),
+        false)
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
new file mode 100644
index 00000000000..4ef39ddfd25
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.{BufferedReader, InputStreamReader}
+import java.nio.charset.StandardCharsets
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FSDataOutputStream, Path}
+import org.json4s.{Formats, NoTypeHints}
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
+
+/**
+ * Metadata for a state store instance.
+ *
+ * - `storeName`: name of the store within its operator, e.g. "default" for single-store
+ *   operators or "left-keyToNumValues" for one of the stream-stream join stores.
+ * - `numColsPrefixKey`: prefix-key column count; > 0 if prefix scan is enabled, 0 otherwise.
+ * - `numPartitions`: number of partitions of the store.
+ */
+trait StateStoreMetadata {
+  def storeName: String
+  def numColsPrefixKey: Int
+  def numPartitions: Int
+}
+
+/** Version 1 of [[StateStoreMetadata]]; serialized as JSON inside [[OperatorStateMetadataV1]]. */
+case class StateStoreMetadataV1(storeName: String, numColsPrefixKey: Int, numPartitions: Int)
+  extends StateStoreMetadata
+
+/**
+ * Information about a stateful operator.
+ *
+ * - `operatorId`: id of the operator; it is also the name of the operator's state checkpoint
+ *   sub-directory (`<checkpoint>/state/<operatorId>`).
+ * - `operatorName`: the operator's `shortName` (e.g. "stateStoreSave", "symmetricHashJoin").
+ *   It does not need to be unique among the stateful operators of a query.
+ */
+trait OperatorInfo {
+  def operatorId: Long
+  def operatorName: String
+}
+
+/** Version 1 of [[OperatorInfo]]; serialized as JSON inside [[OperatorStateMetadataV1]]. */
+case class OperatorInfoV1(operatorId: Long, operatorName: String) extends OperatorInfo
+
+/**
+ * Content of an operator's state metadata file. `version` identifies the format and is
+ * written as the `v<version>` header line that precedes the JSON body.
+ */
+trait OperatorStateMetadata {
+  def version: Int
+}
+
+/**
+ * Version 1 of the operator state metadata: information about the operator plus the metadata
+ * of each of its state stores.
+ *
+ * NOTE: `stateStoreInfo` is an `Array`, so the generated `equals`/`hashCode` compare it by
+ * reference. Compare the arrays element-wise (e.g. `sameElements` or `toSeq`) when checking
+ * contents.
+ */
+case class OperatorStateMetadataV1(
+    operatorInfo: OperatorInfoV1,
+    stateStoreInfo: Array[StateStoreMetadataV1]) extends OperatorStateMetadata {
+  override def version: Int = 1
+}
+
+/**
+ * Location and JSON (de)serialization of [[OperatorStateMetadataV1]].
+ *
+ * The body is encoded with json4s (jackson backend) without type hints.
+ */
+object OperatorStateMetadataV1 {
+
+  private implicit val formats: Formats = Serialization.formats(NoTypeHints)
+
+  private implicit val manifest: Manifest[OperatorStateMetadataV1] = Manifest
+    .classType[OperatorStateMetadataV1](implicitly[ClassTag[OperatorStateMetadataV1]].runtimeClass)
+
+  /** Path of the metadata file: `<stateCheckpointPath>/_metadata/metadata`. */
+  def metadataFilePath(stateCheckpointPath: Path): Path =
+    new Path(new Path(stateCheckpointPath, "_metadata"), "metadata")
+
+  /**
+   * Parses the JSON body of a version 1 metadata file. The caller is expected to have already
+   * consumed the `v1` header line from `in`.
+   */
+  def deserialize(in: BufferedReader): OperatorStateMetadata = {
+    Serialization.read[OperatorStateMetadataV1](in)
+  }
+
+  /**
+   * Writes the JSON body of `operatorStateMetadata` to `out` (without the version header).
+   *
+   * @throws IllegalArgumentException if the metadata is not an [[OperatorStateMetadataV1]],
+   *                                  the only format this method can encode.
+   */
+  def serialize(
+      out: FSDataOutputStream,
+      operatorStateMetadata: OperatorStateMetadata): Unit = operatorStateMetadata match {
+    case v1: OperatorStateMetadataV1 => Serialization.write(v1, out)
+    case other => throw new IllegalArgumentException(
+      s"Unsupported operator state metadata version ${other.version}; only version 1 can be " +
+        "serialized")
+  }
+}
+
+/**
+ * Writes an [[OperatorStateMetadata]] into the state checkpoint directory of one operator.
+ *
+ * @param stateCheckpointPath state checkpoint directory of the operator
+ *                            (e.g. `<checkpoint>/state/<operatorId>`)
+ * @param hadoopConf          configuration used to create the [[CheckpointFileManager]]
+ */
+class OperatorStateMetadataWriter(stateCheckpointPath: Path, hadoopConf: Configuration)
+  extends Logging {
+
+  private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath)
+
+  private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf)
+
+  /**
+   * Writes `operatorMetadata` as a `v<version>` header line followed by its JSON body.
+   * Does nothing if the metadata file already exists, so an existing file is never replaced.
+   *
+   * The file is created via `createAtomic` without overwrite. On any failure the output stream
+   * is cancelled and the original error is rethrown.
+   */
+  def write(operatorMetadata: OperatorStateMetadata): Unit = {
+    if (!fm.exists(metadataFilePath)) {
+      fm.mkdirs(metadataFilePath.getParent)
+      val outputStream = fm.createAtomic(metadataFilePath, overwriteIfPossible = false)
+      try {
+        outputStream.write(s"v${operatorMetadata.version}\n".getBytes(StandardCharsets.UTF_8))
+        OperatorStateMetadataV1.serialize(outputStream, operatorMetadata)
+        outputStream.close()
+      } catch {
+        // Deliberately catches every Throwable: the stream must be cancelled before any error,
+        // fatal or not, propagates. The error is always rethrown.
+        case e: Throwable =>
+          logError(s"Fail to write state metadata file to $metadataFilePath", e)
+          outputStream.cancel()
+          throw e
+      }
+    }
+  }
+}
+
+/**
+ * Reads the [[OperatorStateMetadata]] written by [[OperatorStateMetadataWriter]] from the state
+ * checkpoint directory of one operator.
+ *
+ * @param stateCheckpointPath state checkpoint directory of the operator
+ *                            (e.g. `<checkpoint>/state/<operatorId>`)
+ * @param hadoopConf          configuration used to create the [[CheckpointFileManager]]
+ */
+class OperatorStateMetadataReader(stateCheckpointPath: Path, hadoopConf: Configuration) {
+
+  private val metadataFilePath = OperatorStateMetadataV1.metadataFilePath(stateCheckpointPath)
+
+  private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf)
+
+  /**
+   * Reads the metadata file: a `v<version>` header line followed by a JSON body.
+   * The header is checked with `MetadataVersionUtil.validateVersion` using 1 as the maximum
+   * supported version, and the body is then parsed as [[OperatorStateMetadataV1]].
+   */
+  def read(): OperatorStateMetadata = {
+    val inputStream = fm.open(metadataFilePath)
+    // Everything after open() runs inside the try so the stream is always closed.
+    try {
+      val inputReader =
+        new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
+      val versionStr = inputReader.readLine()
+      val version = MetadataVersionUtil.validateVersion(versionStr, 1)
+      assert(version == 1)
+      OperatorStateMetadataV1.deserialize(inputReader)
+    } finally {
+      inputStream.close()
+    }
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index cb01fa9ff6d..f534c5f108a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -65,7 +65,7 @@ case class StatefulOperatorStateInfo(
 trait StatefulOperator extends SparkPlan {
   def stateInfo: Option[StatefulOperatorStateInfo]
 
-  protected def getStateInfo: StatefulOperatorStateInfo = {
+  def getStateInfo: StatefulOperatorStateInfo = {
     stateInfo.getOrElse {
       throw new IllegalStateException("State location not present for 
execution")
     }
@@ -179,6 +179,15 @@ trait StateStoreWriter extends StatefulOperator with 
PythonSQLMetrics { self: Sp
   /** Records the duration of running `body` for the next query progress update. */
   protected def timeTakenMs(body: => Unit): Long = Utils.timeTakenMs(body)._2
 
+  /**
+   * Metadata of this stateful operator and its state stores. It is written to the operator's
+   * state checkpoint directory when a batch is planned with `isFirstBatch` set.
+   *
+   * This default describes a single store named `StateStoreId.DEFAULT_STORE_NAME` without
+   * prefix scan (`numColsPrefixKey = 0`). Operators with different stores (e.g. stream-stream
+   * join, session window) override it.
+   */
+  def operatorStateMetadata(): OperatorStateMetadata = {
+    val info = getStateInfo
+    val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
+    val stateStoreInfo =
+      Array(StateStoreMetadataV1(StateStoreId.DEFAULT_STORE_NAME, 0, info.numPartitions))
+    OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
+  }
+
   /** Set the operator level metrics */
   protected def setOperatorMetrics(numStateStoreInstances: Int = 1): Unit = {
     assert(numStateStoreInstances >= 1, s"invalid number of stores: 
$numStateStoreInstances")
@@ -741,6 +750,8 @@ case class SessionWindowStateStoreSaveExec(
 
   override def keyExpressions: Seq[Attribute] = keyWithoutSessionExpressions
 
+  // Recorded as `operatorName` in this operator's state metadata file.
+  override def shortName: String = "sessionWindowStateStoreSaveExec"
+
   private val stateManager = StreamingSessionWindowStateManager.createStateManager(
     keyWithoutSessionExpressions, sessionExpression, child.output, stateFormatVersion)
 
@@ -828,6 +839,14 @@ case class SessionWindowStateStoreSaveExec(
       keyWithoutSessionExpressions, getStateInfo, conf) :: Nil
   }
 
+  /**
+   * Metadata of this operator and its single default state store. Unlike the default
+   * implementation, `numColsPrefixKey` is taken from the state manager
+   * (`getNumColsForPrefixKey`) instead of being hard-coded to 0.
+   */
+  override def operatorStateMetadata(): OperatorStateMetadata = {
+    val info = getStateInfo
+    val operatorInfo = OperatorInfoV1(info.operatorId, shortName)
+    val stateStoreInfo = Array(StateStoreMetadataV1(
+      StateStoreId.DEFAULT_STORE_NAME, stateManager.getNumColsForPrefixKey, info.numPartitions))
+    OperatorStateMetadataV1(operatorInfo, stateStoreInfo)
+  }
+
   override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
     (outputMode.contains(Append) || outputMode.contains(Update)) &&
       eventTimeWatermarkForEviction.isDefined &&
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
new file mode 100644
index 00000000000..b75da9084e8
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadataSuite.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
+import org.apache.spark.sql.streaming.OutputMode.Complete
+import org.apache.spark.sql.test.SharedSparkSession
+
+
+class OperatorStateMetadataSuite extends StreamTest with SharedSparkSession {
+  import testImplicits._
+
+  private lazy val hadoopConf = spark.sessionState.newHadoopConf()
+
+  private def numShufflePartitions = spark.sessionState.conf.numShufflePartitions
+
+  /**
+   * Reads the metadata of operator `operatorId` from `<checkpointDir>/state/<operatorId>` and
+   * checks it against `expectedMetadata`.
+   *
+   * `stateStoreInfo` is an `Array`, so case-class `==` would compare it by reference. It is
+   * compared element-wise as a `Seq` instead, which also yields a readable diff on failure.
+   */
+  private def checkOperatorStateMetadata(
+      checkpointDir: String,
+      operatorId: Int,
+      expectedMetadata: OperatorStateMetadataV1): Unit = {
+    val statePath = new Path(checkpointDir, s"state/$operatorId")
+    val operatorMetadata = new OperatorStateMetadataReader(statePath, hadoopConf).read()
+      .asInstanceOf[OperatorStateMetadataV1]
+    assert(operatorMetadata.operatorInfo == expectedMetadata.operatorInfo)
+    assert(operatorMetadata.stateStoreInfo.toSeq == expectedMetadata.stateStoreInfo.toSeq)
+  }
+
+  test("Serialize and deserialize stateful operator metadata") {
+    withTempDir { checkpointDir =>
+      val statePath = new Path(checkpointDir.toString, "state/0")
+      val stateStoreInfo = (1 to 4).map(i => StateStoreMetadataV1(s"store$i", 1, 200))
+      val operatorInfo = OperatorInfoV1(1, "Join")
+      val operatorMetadata = OperatorStateMetadataV1(operatorInfo, stateStoreInfo.toArray)
+      new OperatorStateMetadataWriter(statePath, hadoopConf).write(operatorMetadata)
+      checkOperatorStateMetadata(checkpointDir.toString, 0, operatorMetadata)
+    }
+  }
+
+  test("Stateful operator metadata for streaming aggregation") {
+    withTempDir { checkpointDir =>
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData.toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Complete)(
+        StartStream(checkpointLocation = checkpointDir.toString),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        StopStream
+      )
+
+      val expectedMetadata = OperatorStateMetadataV1(OperatorInfoV1(0, "stateStoreSave"),
+        Array(StateStoreMetadataV1("default", 0, numShufflePartitions)))
+      checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata)
+    }
+  }
+
+  test("Stateful operator metadata for streaming join") {
+    withTempDir { checkpointDir =>
+      val input1 = MemoryStream[Int]
+      val input2 = MemoryStream[Int]
+
+      val df1 = input1.toDF.select($"value" as "key", ($"value" * 2) as "leftValue")
+      val df2 = input2.toDF.select($"value" as "key", ($"value" * 3) as "rightValue")
+      val joined = df1.join(df2, "key")
+
+      testStream(joined)(
+        StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+        AddData(input1, 1),
+        CheckAnswer(),
+        AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join
+        CheckNewAnswer((1, 2, 3)),
+        StopStream
+      )
+
+      val expectedStateStoreInfo = Array(
+        StateStoreMetadataV1("left-keyToNumValues", 0, numShufflePartitions),
+        StateStoreMetadataV1("left-keyWithIndexToValue", 0, numShufflePartitions),
+        StateStoreMetadataV1("right-keyToNumValues", 0, numShufflePartitions),
+        StateStoreMetadataV1("right-keyWithIndexToValue", 0, numShufflePartitions))
+
+      val expectedMetadata = OperatorStateMetadataV1(
+        OperatorInfoV1(0, "symmetricHashJoin"), expectedStateStoreInfo)
+      checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata)
+    }
+  }
+
+  test("Stateful operator metadata for streaming session window") {
+    withTempDir { checkpointDir =>
+      val input = MemoryStream[(String, Long)]
+      val sessionWindow: Column = session_window($"eventTime", "10 seconds")
+
+      val events = input.toDF()
+        .select($"_1".as("value"), $"_2".as("timestamp"))
+        .withColumn("eventTime", $"timestamp".cast("timestamp"))
+        .withWatermark("eventTime", "30 seconds")
+        .selectExpr("explode(split(value, ' ')) AS sessionId", "eventTime")
+
+      val streamingDf = events
+        .groupBy(sessionWindow as Symbol("session"), $"sessionId")
+        .agg(count("*").as("numEvents"))
+        .selectExpr("sessionId", "CAST(session.start AS LONG)", "CAST(session.end AS LONG)",
+          "CAST(session.end AS LONG) - CAST(session.start AS LONG) AS durationMs",
+          "numEvents")
+
+      testStream(streamingDf, OutputMode.Complete())(
+        StartStream(checkpointLocation = checkpointDir.toString),
+        AddData(input,
+          ("hello world spark streaming", 40L),
+          ("world hello structured streaming", 41L)
+        ),
+        CheckNewAnswer(
+          ("hello", 40, 51, 11, 2),
+          ("world", 40, 51, 11, 2),
+          ("streaming", 40, 51, 11, 2),
+          ("spark", 40, 50, 10, 1),
+          ("structured", 41, 51, 10, 1)
+        ),
+        StopStream
+      )
+
+      // Session window state uses prefix scan, hence numColsPrefixKey = 1.
+      val expectedMetadata = OperatorStateMetadataV1(
+        OperatorInfoV1(0, "sessionWindowStateStoreSaveExec"),
+        Array(StateStoreMetadataV1("default", 1, numShufflePartitions))
+      )
+      checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata)
+    }
+  }
+
+  test("Stateful operator metadata for multiple operators") {
+    withTempDir { checkpointDir =>
+      val inputData = MemoryStream[Int]
+
+      val stream = inputData.toDF()
+        .withColumn("eventTime", timestamp_seconds($"value"))
+        .withWatermark("eventTime", "0 seconds")
+        .groupBy(window($"eventTime", "5 seconds").as("window"))
+        .agg(count("*").as("count"))
+        .groupBy(window($"window", "10 seconds"))
+        .agg(count("*").as("count"), sum("count").as("sum"))
+        .select($"window".getField("start").cast("long").as[Long],
+          $"count".as[Long], $"sum".as[Long])
+
+      testStream(stream)(
+        StartStream(checkpointLocation = checkpointDir.toString),
+        AddData(inputData, 10 to 21: _*),
+        CheckNewAnswer((10, 2, 10)),
+        StopStream
+      )
+      val expectedMetadata0 = OperatorStateMetadataV1(OperatorInfoV1(0, "stateStoreSave"),
+        Array(StateStoreMetadataV1("default", 0, numShufflePartitions)))
+      val expectedMetadata1 = OperatorStateMetadataV1(OperatorInfoV1(1, "stateStoreSave"),
+        Array(StateStoreMetadataV1("default", 0, numShufflePartitions)))
+      checkOperatorStateMetadata(checkpointDir.toString, 0, expectedMetadata0)
+      checkOperatorStateMetadata(checkpointDir.toString, 1, expectedMetadata1)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to