This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fe1cf3200223 [SPARK-49656][SS] Add support for state variables with 
value state collection types and read change feed options
fe1cf3200223 is described below

commit fe1cf3200223c33ed4670bfa5924d5a4053c8ef9
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Thu Sep 26 17:38:58 2024 +0900

    [SPARK-49656][SS] Add support for state variables with value state 
collection types and read change feed options
    
    ### What changes were proposed in this pull request?
    Add support for state variables with value state collection types and read 
change feed options
    
    ### Why are the changes needed?
    Without this, we cannot support reading per key changes for state variables 
used with stateful processors.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    Users can now query value state variables with the following query:
    
    ```
            val changeFeedDf = spark.read
                .format("statestore")
                .option(StateSourceOptions.PATH, <checkpoint_loc>)
                .option(StateSourceOptions.STATE_VAR_NAME, <state_var_name>)
                .option(StateSourceOptions.READ_CHANGE_FEED, true)
                .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
                .load()
    ```
    
    ### How was this patch tested?
    Added unit tests
    
    ```
    [info] Run completed in 17 seconds, 318 milliseconds.
    [info] Total number of tests run: 2
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48148 from anishshri-db/task/SPARK-49656.
    
    Authored-by: Anish Shrigondekar <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../datasources/v2/state/StateDataSource.scala     |  10 +-
 .../v2/state/StatePartitionReader.scala            |  10 +-
 .../state/HDFSBackedStateStoreProvider.scala       |  10 +-
 .../state/RocksDBStateStoreProvider.scala          |  79 ++++++++++++---
 .../sql/execution/streaming/state/StateStore.scala |   6 +-
 .../streaming/state/StateStoreChangelog.scala      |  11 ++-
 .../StateDataSourceTransformWithStateSuite.scala   | 107 +++++++++++++++++----
 7 files changed, 190 insertions(+), 43 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 39bc4dd9fb9c..edddfbd6ccae 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -33,7 +33,7 @@ import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{J
 import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
 import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader,
 StateMetadataTableEntry}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, 
OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties, 
TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, 
OffsetSeqMetadata, StateVariableType, TimerStateUtils, 
TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS,
 DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
 RightSide}
 import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, 
StateStoreConf, StateStoreId, StateStoreProviderId}
@@ -170,13 +170,15 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       }
 
       val stateVars = twsOperatorProperties.stateVariables
-      if (stateVars.filter(stateVar => stateVar.stateName == 
stateVarName).size != 1) {
+      val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName == 
stateVarName)
+      if (stateVarInfo.size != 1) {
         throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
           s"State variable $stateVarName is not defined for the 
transformWithState operator.")
       }
 
-      // TODO: Support change feed and transformWithState together
-      if (sourceOptions.readChangeFeed) {
+      // TODO: add support for list and map type
+      if (sourceOptions.readChangeFeed &&
+        stateVarInfo.head.stateVariableType != StateVariableType.ValueState) {
         throw 
StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED,
           StateSourceOptions.STATE_VAR_NAME))
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index d77d97f0057f..b925aee5b627 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -223,10 +223,18 @@ class StateStoreChangeDataPartitionReader(
       throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay(
         provider.getClass.toString)
     }
+
+    val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) {
+      Some(stateVariableInfoOpt.get.stateName)
+    } else {
+      None
+    }
+
     provider.asInstanceOf[SupportsFineGrainedReplay]
       .getStateStoreChangeDataReader(
         partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 
1,
-        partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1)
+        partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1,
+        colFamilyNameOpt)
   }
 
   override lazy val iter: Iterator[InternalRow] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index d9f4443b7961..884b8aa3853c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -991,8 +991,16 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
     result
   }
 
-  override def getStateStoreChangeDataReader(startVersion: Long, endVersion: 
Long):
+  override def getStateStoreChangeDataReader(
+      startVersion: Long,
+      endVersion: Long,
+      colFamilyNameOpt: Option[String] = None):
     StateStoreChangeDataReader = {
+    // Multiple column families are not supported with 
HDFSBackedStateStoreProvider
+    if (colFamilyNameOpt.isDefined) {
+      throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
+    }
+
     new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, 
endVersion,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
       keySchema, valueSchema)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 85f80ce9eb1a..6ab634668bc2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -498,7 +498,10 @@ private[sql] class RocksDBStateStoreProvider
     }
   }
 
-  override def getStateStoreChangeDataReader(startVersion: Long, endVersion: 
Long):
+  override def getStateStoreChangeDataReader(
+      startVersion: Long,
+      endVersion: Long,
+      colFamilyNameOpt: Option[String] = None):
     StateStoreChangeDataReader = {
     val statePath = stateStoreId.storeCheckpointLocation()
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
@@ -508,7 +511,8 @@ private[sql] class RocksDBStateStoreProvider
       startVersion,
       endVersion,
       CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec),
-      keyValueEncoderMap)
+      keyValueEncoderMap,
+      colFamilyNameOpt)
   }
 
   /**
@@ -676,27 +680,70 @@ class RocksDBStateStoreChangeDataReader(
     endVersion: Long,
     compressionCodec: CompressionCodec,
     keyValueEncoderMap:
-      ConcurrentHashMap[String, (RocksDBKeyStateEncoder, 
RocksDBValueStateEncoder)])
+      ConcurrentHashMap[String, (RocksDBKeyStateEncoder, 
RocksDBValueStateEncoder)],
+    colFamilyNameOpt: Option[String] = None)
   extends StateStoreChangeDataReader(
-    fm, stateLocation, startVersion, endVersion, compressionCodec) {
+    fm, stateLocation, startVersion, endVersion, compressionCodec, 
colFamilyNameOpt) {
 
   override protected var changelogSuffix: String = "changelog"
 
+  private def getColFamilyIdBytes: Option[Array[Byte]] = {
+    if (colFamilyNameOpt.isDefined) {
+      val colFamilyName = colFamilyNameOpt.get
+      if (!keyValueEncoderMap.containsKey(colFamilyName)) {
+        throw new IllegalStateException(
+          s"Column family $colFamilyName not found in the key value encoder 
map")
+      }
+      Some(keyValueEncoderMap.get(colFamilyName)._1.getColumnFamilyIdBytes())
+    } else {
+      None
+    }
+  }
+
+  private val colFamilyIdBytesOpt: Option[Array[Byte]] = getColFamilyIdBytes
+
   override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = {
-    val reader = currentChangelogReader()
-    if (reader == null) {
-      return null
+    var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null
+    val currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder) =
+      keyValueEncoderMap.get(colFamilyNameOpt
+        .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME))
+
+    if (colFamilyIdBytesOpt.isDefined) {
+      // If we are reading records for a particular column family, the 
corresponding vcf id
+      // will be encoded in the key byte array. We need to extract that and 
compare for the
+      // expected column family id. If it matches, we return the record. If 
not, we move to
+      // the next record. Note that this has to be handled across multiple 
changelog files and we
+      // rely on the currentChangelogReader to move to the next changelog file 
when needed.
+      while (currRecord == null) {
+        val reader = currentChangelogReader()
+        if (reader == null) {
+          return null
+        }
+
+        val nextRecord = reader.next()
+        val colFamilyIdBytes: Array[Byte] = colFamilyIdBytesOpt.get
+        val endIndex = colFamilyIdBytes.size
+        // Function checks for byte arrays being equal
+        // from index 0 to endIndex - 1 (both inclusive)
+        if (java.util.Arrays.equals(nextRecord._2, 0, endIndex,
+          colFamilyIdBytes, 0, endIndex)) {
+          currRecord = nextRecord
+        }
+      }
+    } else {
+      val reader = currentChangelogReader()
+      if (reader == null) {
+        return null
+      }
+      currRecord = reader.next()
     }
-    val (recordType, keyArray, valueArray) = reader.next()
-    // Todo: does not support multiple virtual column families
-    val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) =
-      keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME)
-    val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray)
-    if (valueArray == null) {
-      (recordType, keyRow, null, currentChangelogVersion - 1)
+
+    val keyRow = currEncoder._1.decodeKey(currRecord._2)
+    if (currRecord._3 == null) {
+      (currRecord._1, keyRow, null, currentChangelogVersion - 1)
     } else {
-      val valueRow = rocksDBValueStateEncoder.decodeValue(valueArray)
-      (recordType, keyRow, valueRow, currentChangelogVersion - 1)
+      val valueRow = currEncoder._2.decodeValue(currRecord._3)
+      (currRecord._1, keyRow, valueRow, currentChangelogVersion - 1)
     }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d55a973a14e1..6e616cc71a80 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -519,10 +519,14 @@ trait SupportsFineGrainedReplay {
    *
    * @param startVersion starting changelog version
    * @param endVersion ending changelog version
+   * @param colFamilyNameOpt optional column family name to read from
    * @return iterator that gives tuple(recordType: [[RecordType.Value]], 
nested key: [[UnsafeRow]],
    *         nested value: [[UnsafeRow]], batchId: [[Long]])
    */
-  def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long):
+  def getStateStoreChangeDataReader(
+      startVersion: Long,
+      endVersion: Long,
+      colFamilyNameOpt: Option[String] = None):
     NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)]
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index 651d72da1609..e89550da37e0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -397,13 +397,15 @@ class StateStoreChangelogReaderV2(
  * @param startVersion start version of the changelog file to read
  * @param endVersion end version of the changelog file to read
 * @param compressionCodec de-compression method used for reading changelog 
file
+ * @param colFamilyNameOpt optional column family name to read from
  */
 abstract class StateStoreChangeDataReader(
     fm: CheckpointFileManager,
     stateLocation: Path,
     startVersion: Long,
     endVersion: Long,
-    compressionCodec: CompressionCodec)
+    compressionCodec: CompressionCodec,
+    colFamilyNameOpt: Option[String] = None)
   extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with 
Logging {
 
   assert(startVersion >= 1)
@@ -451,9 +453,12 @@ abstract class StateStoreChangeDataReader(
         finished = true
         return null
       }
-      // Todo: Does not support StateStoreChangelogReaderV2
-      changelogReader =
+
+      changelogReader = if (colFamilyNameOpt.isDefined) {
+        new StateStoreChangelogReaderV2(fm, fileIterator.next(), 
compressionCodec)
+      } else {
         new StateStoreChangelogReaderV1(fm, fileIterator.next(), 
compressionCodec)
+      }
     }
     changelogReader
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index bd047d1132fb..84c6eb54681a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -186,18 +186,49 @@ class StateDataSourceTransformWithStateSuite extends 
StateStoreMetricsTest
         }
         assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue])
         assert(ex1.getMessage.contains("Registered timers are not available"))
+      }
+    }
+  }
 
-        // TODO: this should be removed when readChangeFeed is supported for 
value state
-        val ex2 = intercept[Exception] {
-          spark.read
+  testWithChangelogCheckpointingEnabled("state data source cdf integration - " 
+
+    "value state with single variable") {
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new StatefulProcessorWithSingleValueVar(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          AddData(inputData, "b"),
+          CheckNewAnswer(("b", "1")),
+          StopStream
+        )
+
+        val changeFeedDf = spark.read
             .format("statestore")
             .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
             .option(StateSourceOptions.STATE_VAR_NAME, "valueState")
-            .option(StateSourceOptions.READ_CHANGE_FEED, "true")
+            .option(StateSourceOptions.READ_CHANGE_FEED, true)
             .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
             .load()
-        }
-        assert(ex2.isInstanceOf[StateDataSourceConflictOptions])
+
+        val opDf = changeFeedDf.selectExpr(
+          "change_type",
+          "key.value AS groupingKey",
+          "value.id AS valueId", "value.name AS valueName",
+          "partition_id")
+
+        checkAnswer(opDf,
+          Seq(Row("update", "a", 1L, "dummyKey", 0), Row("update", "b", 1L, 
"dummyKey", 1)))
       }
     }
   }
@@ -260,19 +291,61 @@ class StateDataSourceTransformWithStateSuite extends 
StateStoreMetricsTest
         }
         assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue])
         assert(ex.getMessage.contains("State variable non-exist is not 
defined"))
+      }
+    }
+  }
 
-        // TODO: this should be removed when readChangeFeed is supported for 
TTL based state
-        // variables
-        val ex1 = intercept[Exception] {
-          spark.read
-            .format("statestore")
-            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
-            .option(StateSourceOptions.STATE_VAR_NAME, "countState")
-            .option(StateSourceOptions.READ_CHANGE_FEED, "true")
-            .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
-            .load()
+  testWithChangelogCheckpointingEnabled("state data source cdf integration - " 
+
+    "value state with single variable and TTL") {
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new StatefulProcessorWithTTL(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, "a"),
+          AddData(inputData, "b"),
+          Execute { _ =>
+            // wait for the batch to run since we are using processing time
+            Thread.sleep(5000)
+          },
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+          .option(StateSourceOptions.READ_CHANGE_FEED, true)
+          .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+          "key.value", "value.value", "value.ttlExpirationMs", "partition_id")
+
+        var count = 0L
+        resultDf.collect().foreach { row =>
+          count = count + 1
+          assert(row.getLong(2) > 0)
         }
-        assert(ex1.isInstanceOf[StateDataSourceConflictOptions])
+
+        // verify that 2 state rows are present
+        assert(count === 2)
+
+        val answerDf = stateReaderDf.selectExpr(
+          "change_type",
+          "key.value AS groupingKey",
+          "value.value.value AS valueId", "partition_id")
+        checkAnswer(answerDf,
+          Seq(Row("update", "a", 1L, 0), Row("update", "b", 1L, 1)))
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to