This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0497b386ec4e [SPARK-54824] Add support for multiGet and deleteRange for RocksDB State Store
0497b386ec4e is described below

commit 0497b386ec4e939a7229cfb5bc6f534f8aa79f5d
Author: zeruibao <[email protected]>
AuthorDate: Fri Jan 16 12:57:26 2026 -0800

    [SPARK-54824] Add support for multiGet and deleteRange for RocksDB State Store
    
    ### What changes were proposed in this pull request?
    Add support for multiGet and deleteRange for RocksDB State Store
    
    ### Why are the changes needed?
    For some streaming operators, using multiGet and deleteRange can improve read/write performance.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #53583 from zeruibao/zeruibao/SPARK-54824-add-support-for-multi-get-and-rangeDelete.
    
    Authored-by: zeruibao <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../state/HDFSBackedStateStoreProvider.scala       |  9 +++
 .../sql/execution/streaming/state/RocksDB.scala    | 74 ++++++++++++++++++++--
 .../streaming/state/RocksDBStateEncoder.scala      |  7 ++
 .../state/RocksDBStateStoreProvider.scala          | 35 +++++++++-
 .../sql/execution/streaming/state/StateStore.scala | 33 ++++++++++
 .../streaming/state/RocksDBStateStoreSuite.scala   | 38 +++++++++++
 .../streaming/state/StateStoreSuite.scala          | 35 ++++++++++
 7 files changed, 225 insertions(+), 6 deletions(-)
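
For a sense of how the two new calls compose from a caller's perspective, here is a minimal usage sketch. The helper names and surrounding setup are hypothetical; only the multiGet/deleteRange signatures come from this patch:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.state.StateStore

    // Hypothetical helper: batch lookup that tolerates missing keys.
    // multiGet returns one value per input key, in order, with null where
    // the key is absent; rows are copied before buffering since the
    // iterator may reuse row objects.
    def lookupAll(store: StateStore, keys: Array[UnsafeRow]): Seq[UnsafeRow] = {
      store.multiGet(keys, StateStore.DEFAULT_COL_FAMILY_NAME)
        .map(row => if (row != null) row.copy() else null)
        .toSeq
    }

    // Hypothetical helper: drop all state whose key falls in [begin, end).
    // Only valid when the column family uses a RangeKeyScanStateEncoderSpec,
    // the one encoder spec whose encoder reports supportsDeleteRange = true.
    def evictRange(store: StateStore, begin: UnsafeRow, end: UnsafeRow): Unit =
      store.deleteRange(begin, end, StateStore.DEFAULT_COL_FAMILY_NAME)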

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 7291c62f33d9..3a5d14f5581a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -81,6 +81,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = map.get(key)
 
+    override def multiGet(keys: Array[UnsafeRow], colFamilyName: String): Iterator[UnsafeRow] = {
+      keys.iterator.map(key => get(key, colFamilyName))
+    }
+
     override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
       val iter = map.iterator()
       new StateStoreIterator(iter)
@@ -166,6 +170,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       mapToUpdate.get(key)
     }
 
+    override def multiGet(keys: Array[UnsafeRow], colFamilyName: String): Iterator[UnsafeRow] = {
+      assertUseOfDefaultColFamily(colFamilyName)
+      keys.iterator.map(key => mapToUpdate.get(key))
+    }
+
     override def put(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = {
       assertUseOfDefaultColFamily(colFamilyName)
       require(value != null, "Cannot put a null value")
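
Note that the HDFS-backed store gains no native batch path: its multiGet is just the per-key loop above. As a behavioral sketch (store, keys, and cf are assumed from surrounding test context), the batch call and the loop are interchangeable:

    // Same order, with nulls for missing keys, in both forms.
    val batch = store.multiGet(keys, cf).toSeq
    val loop = keys.toSeq.map(key => store.get(key, cf))
    assert(batch == loop)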
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 4c9b2282ba27..8d6ae2a180c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -25,7 +25,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong}
 
 import scala.collection.{mutable, Map}
-import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala
+import scala.jdk.CollectionConverters.{ConcurrentMapHasAsScala, ListHasAsScala}
 import scala.util.Try
 import scala.util.control.NonFatal
 
@@ -1047,6 +1047,43 @@ class RocksDB(
     }
   }
 
+  /**
+   * Get the values for multiple keys in a single batch operation.
+   * Uses RocksDB's native multiGet for efficient batch lookups.
+   *
+   * @param keys   Array of keys to look up
+   * @param cfName The column family name
+   * @return Iterator of values corresponding to the keys (null for keys that don't exist)
+   */
+  def multiGet(
+      keys: Array[Array[Byte]],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[Array[Byte]] = {
+    updateMemoryUsageIfNeeded()
+    // Prepare keys
+    val finalKeys = if (useColumnFamilies) {
+      keys.map(encodeStateRowWithPrefix(_, cfName))
+    } else {
+      keys
+    }
+
+    val keysList = java.util.Arrays.asList(finalKeys: _*)
+
+    // Call RocksDB multiGetAsList
+    val valuesList = db.multiGetAsList(readOptions, keysList)
+
+    // Decode and verify checksums if enabled, then return iterator
+    if (conf.rowChecksumEnabled) {
+      valuesList.asScala.iterator.zipWithIndex.map {
+        case (value, idx) if value != null =>
+          KeyValueChecksumEncoder.decodeAndVerifyValueRowWithChecksum(
+            readVerifier, finalKeys(idx), value)
+        case _ => null
+      }
+    } else {
+      valuesList.asScala.iterator
+    }
+  }
+
   /**
   * This method should give a 100% guarantee of a correct result, whether the key exists or
    * not.
@@ -1068,15 +1105,15 @@ class RocksDB(
   }
 
   /**
-   * Get the values for a given key if present, that were merged (via merge).
+   * Get multiple values for a given key that were stored via the merge operation.
    * This returns the values as an iterator of index ranges, to allow inline access
    * of each value's bytes without copying, for better performance.
    * Note: This method is currently only supported when row checksum is enabled.
    */
-  def multiGet(
+  def getMergedValues(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ArrayIndexRange[Byte]] = {
-    assert(conf.rowChecksumEnabled, "multiGet is only allowed when row checksum is enabled")
+    assert(conf.rowChecksumEnabled, "getMergedValues is only allowed when row checksum is enabled")
     updateMemoryUsageIfNeeded()
 
     val (finalKey, value) = getValue(key, cfName)
@@ -1382,6 +1419,35 @@ class RocksDB(
     }
   }
 
+  /**
+   * Delete all keys in the range [beginKey, endKey).
+   * Uses RocksDB's native deleteRange for efficient bulk deletion.
+   *
+   * @param beginKey The start key of the range (inclusive)
+   * @param endKey   The end key of the range (exclusive)
+   * @param cfName   The column family name
+   */
+  def deleteRange(
+      beginKey: Array[Byte],
+      endKey: Array[Byte],
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    updateMemoryUsageIfNeeded()
+
+    val beginKeyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(beginKey, cfName)
+    } else {
+      beginKey
+    }
+
+    val endKeyWithPrefix = if (useColumnFamilies) {
+      encodeStateRowWithPrefix(endKey, cfName)
+    } else {
+      endKey
+    }
+
+    db.deleteRange(writeOptions, beginKeyWithPrefix, endKeyWithPrefix)
+  }
+
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
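
Both new methods delegate to RocksJava primitives (multiGetAsList and deleteRange). Here is a standalone sketch of those primitives in isolation; the path and sample data are illustrative, not from this patch:

    import java.nio.charset.StandardCharsets.UTF_8
    import scala.jdk.CollectionConverters._
    import org.rocksdb.{Options, ReadOptions, RocksDB, WriteOptions}

    RocksDB.loadLibrary()
    val db = RocksDB.open(new Options().setCreateIfMissing(true), "/tmp/rocksdb-demo")
    try {
      Seq("a", "b", "d").foreach { k =>
        db.put(k.getBytes(UTF_8), s"value-$k".getBytes(UTF_8))
      }

      // multiGetAsList: one result per key, null where the key is absent.
      val keys = Seq("a", "b", "c").map(_.getBytes(UTF_8)).asJava
      val values = db.multiGetAsList(new ReadOptions(), keys).asScala
      assert(values(2) == null) // "c" was never written

      // deleteRange: removes every key in ["a", "c") without iterating,
      // so "a" and "b" are gone but "d" survives.
      db.deleteRange(new WriteOptions(), "a".getBytes(UTF_8), "c".getBytes(UTF_8))
      assert(db.get("d".getBytes(UTF_8)) != null)
    } finally {
      db.close()
    }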
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 36e2dd2d527c..102f38443b8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -44,6 +44,7 @@ import org.apache.spark.unsafe.Platform
 
 sealed trait RocksDBKeyStateEncoder {
   def supportPrefixKeyScan: Boolean
+  def supportsDeleteRange: Boolean
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def encodeKey(row: UnsafeRow): Array[Byte]
   def decodeKey(keyBytes: Array[Byte]): UnsafeRow
@@ -1472,6 +1473,8 @@ class PrefixKeyScanStateEncoder(
   }
 
   override def supportPrefixKeyScan: Boolean = true
+
+  override def supportsDeleteRange: Boolean = false
 }
 
 /**
@@ -1669,6 +1672,8 @@ class RangeKeyScanStateEncoder(
   }
 
   override def supportPrefixKeyScan: Boolean = true
+
+  override def supportsDeleteRange: Boolean = true
 }
 
 /**
@@ -1699,6 +1704,8 @@ class NoPrefixKeyStateEncoder(
 
   override def supportPrefixKeyScan: Boolean = false
 
+  override def supportsDeleteRange: Boolean = false
+
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
     throw new IllegalStateException("This encoder doesn't support prefix key!")
   }
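
Only RangeKeyScanStateEncoder reports supportsDeleteRange = true, because its order-preserving byte encoding is what makes RocksDB's lexicographic [begin, end) line up with a logical key range. A sketch of the spec a caller passes to get that encoder; the schema is illustrative, mirroring the new test further below:

    import org.apache.spark.sql.execution.streaming.state.RangeKeyScanStateEncoderSpec
    import org.apache.spark.sql.types.{LongType, StringType, StructType}

    // Order keys by their first column (a timestamp), as in the
    // deleteRange test added in this patch.
    val keySchema = new StructType()
      .add("ts", LongType)
      .add("id", StringType)
    val encoderSpec = RangeKeyScanStateEncoderSpec(keySchema, Seq(0)) // ordering ordinals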
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 37952376e520..7494c12d028f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -259,6 +259,18 @@ private[sql] class RocksDBStateStoreProvider
       value
     }
 
+    override def multiGet(
+        keys: Array[UnsafeRow],
+        colFamilyName: String): Iterator[UnsafeRow] = {
+      validateAndTransitionState(UPDATE)
+      verify(keys != null && keys.forall(_ != null), "Keys cannot be null")
+      verifyColFamilyOperations("multiGet", colFamilyName)
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      val encodedKeys = keys.map(kvEncoder._1.encodeKey)
+      val encodedValues = rocksDB.multiGet(encodedKeys, colFamilyName)
+      encodedValues.map(kvEncoder._2.decodeValue)
+    }
+
     override def keyExists(key: UnsafeRow, colFamilyName: String): Boolean = {
       validateAndTransitionState(UPDATE)
       verify(key != null, "Key cannot be null")
@@ -291,8 +303,9 @@ private[sql] class RocksDBStateStoreProvider
       "that supports multiple values for a single key.")
 
       if (storeConf.rowChecksumEnabled) {
-        // multiGet provides better perf for row checksum, since it avoids copying values
-        val encodedValuesIterator = rocksDB.multiGet(keyEncoder.encodeKey(key), colFamilyName)
+        // getMergedValues provides better perf for row checksum, since it avoids copying values
+        val encodedValuesIterator =
+          rocksDB.getMergedValues(keyEncoder.encodeKey(key), colFamilyName)
         valueEncoder.decodeValues(encodedValuesIterator)
       } else {
         val encodedValues = rocksDB.get(keyEncoder.encodeKey(key), colFamilyName)
@@ -385,6 +398,24 @@ private[sql] class RocksDBStateStoreProvider
       rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName)
     }
 
+    override def deleteRange(
+        beginKey: UnsafeRow,
+        endKey: UnsafeRow,
+        colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+      validateAndTransitionState(UPDATE)
+      verify(state == UPDATING, "Cannot deleteRange after already committed or aborted")
+      verify(beginKey != null, "Begin key cannot be null")
+      verify(endKey != null, "End key cannot be null")
+      verifyColFamilyOperations("deleteRange", colFamilyName)
+
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      verify(kvEncoder._1.supportsDeleteRange,
+        "deleteRange requires a RangeKeyScanStateEncoderSpec for ordered key 
encoding")
+      val encodedBeginKey = kvEncoder._1.encodeKey(beginKey)
+      val encodedEndKey = kvEncoder._1.encodeKey(endKey)
+      rocksDB.deleteRange(encodedBeginKey, encodedEndKey, colFamilyName)
+    }
+
     override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
       validateAndTransitionState(UPDATE)
       // Note this verify function only verify on the colFamilyName being valid,
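
Since the provider rejects deleteRange up front when the encoder is not range-ordered, call sites that may run against arbitrary encoders can guard the call. A hedged sketch: store, begin, and end are assumed from context, and the exception type assumes the provider's verify helper throws IllegalStateException:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.state.StateStore

    // Returns true if the range was deleted, false if this store/encoder
    // combination does not support range deletion.
    def tryEvictRange(store: StateStore, begin: UnsafeRow, end: UnsafeRow): Boolean = {
      try {
        store.deleteRange(begin, end, StateStore.DEFAULT_COL_FAMILY_NAME)
        true
      } catch {
        case _: IllegalStateException | _: UnsupportedOperationException => false
      }
    }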
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 08d629c38b72..3f90cb4edbc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -115,6 +115,19 @@ trait ReadStateStore {
       key: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow
 
+  /**
+   * Get the values for multiple keys in a single batch operation.
+   * Default implementation throws UnsupportedOperationException.
+   * Providers that support batch retrieval should override this method.
+   *
+   * @param keys          Array of keys to look up
+   * @param colFamilyName The column family name
+   * @return Iterator of values corresponding to the keys (null for keys that don't exist)
+   */
+  def multiGet(keys: Array[UnsafeRow], colFamilyName: String): Iterator[UnsafeRow] = {
+    throw new UnsupportedOperationException("multiGet is not supported by this StateStore")
+  }
+
   /**
   * Check if a key exists in the store, with 100% guarantee of a correct result.
    *
@@ -249,6 +262,22 @@ trait StateStore extends ReadStateStore {
       key: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
 
+  /**
+   * Delete all keys in the range [beginKey, endKey).
+   * Uses RocksDB's native deleteRange for efficient bulk deletion.
+   * Default implementation throws UnsupportedOperationException.
+   * Providers that support range deletion should override this method.
+   *
+   * @note This operation is NOT recorded in the changelog.
+   * @param beginKey      The start key of the range (inclusive)
+   * @param endKey        The end key of the range (exclusive)
+   * @param colFamilyName The column family name
+   */
+  def deleteRange(
+      beginKey: UnsafeRow,
+      endKey: UnsafeRow,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    throw new UnsupportedOperationException("deleteRange is not supported by 
this StateStore")
+  }
+
   /**
   * Merges the provided value with existing values of a non-null key. If an existing
    * value does not exist, this operation behaves as [[StateStore.put()]].
@@ -329,6 +358,10 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = store.get(key,
     colFamilyName)
 
+  override def multiGet(keys: Array[UnsafeRow], colFamilyName: String): Iterator[UnsafeRow] = {
+    store.multiGet(keys, colFamilyName)
+  }
+
   override def keyExists(
       key: UnsafeRow,
       colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Boolean = {
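
Because the trait defaults throw, a caller that must work across providers can degrade to per-key gets. A minimal sketch; the helper name is hypothetical:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.state.ReadStateStore

    // Use the batch path when available, otherwise fall back to per-key
    // gets, mirroring what HDFSBackedStateStoreProvider does internally.
    // The default multiGet throws on invocation, not on iteration, so the
    // try/catch is sufficient.
    def multiGetOrFallback(
        store: ReadStateStore,
        keys: Array[UnsafeRow],
        colFamilyName: String): Iterator[UnsafeRow] = {
      try {
        store.multiGet(keys, colFamilyName)
      } catch {
        case _: UnsupportedOperationException =>
          keys.iterator.map(key => store.get(key, colFamilyName))
      }
    }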
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index f702a9827934..8832d7e09933 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -2401,6 +2401,44 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  test("deleteRange - bulk deletion of keys in range") {
+    tryWithProviderResource(
+      newStoreProvider(
+        keySchemaWithRangeScan,
+        RangeKeyScanStateEncoderSpec(keySchemaWithRangeScan, Seq(0)),
+        useColumnFamilies = false)) { provider =>
+      val store = provider.getStore(0)
+      try {
+        // Put keys with timestamps that will be ordered
+        // Keys: (1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")
+        store.put(dataToKeyRowWithRangeScan(1L, "a"), dataToValueRow(10))
+        store.put(dataToKeyRowWithRangeScan(2L, "b"), dataToValueRow(20))
+        store.put(dataToKeyRowWithRangeScan(3L, "c"), dataToValueRow(30))
+        store.put(dataToKeyRowWithRangeScan(4L, "d"), dataToValueRow(40))
+        store.put(dataToKeyRowWithRangeScan(5L, "e"), dataToValueRow(50))
+
+        // Verify all keys exist
+        assert(store.iterator().toSeq.length === 5)
+
+        // Delete range [2, 4) - should delete keys with timestamps 2 and 3
+        val beginKey = dataToKeyRowWithRangeScan(2L, "")
+        val endKey = dataToKeyRowWithRangeScan(4L, "")
+        store.deleteRange(beginKey, endKey)
+
+        // Verify remaining keys
+        val remainingKeys = store.iterator().map { kv =>
+          keyRowWithRangeScanToData(kv.key)
+        }.toSeq
+
+        // Keys 1, 4, 5 should remain; keys 2, 3 should be deleted
+        assert(remainingKeys.length === 3)
+        assert(remainingKeys.map(_._1).toSet === Set(1L, 4L, 5L))
+      } finally {
+        if (!store.hasCommitted) store.abort()
+      }
+    }
+  }
+
   test("Rocks DB task completion listener does not double unlock 
acquireThread") {
     // This test verifies that a thread that locks then unlocks the db and then
     // fires a completion listener (Thread 1) does not unlock the lock validly
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 13a85fdd1ed5..b13998708b61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -1868,6 +1868,41 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     }
   }
 
+  testWithAllCodec("multiGet - batch retrieval of multiple keys") { 
colFamiliesEnabled =>
+    tryWithProviderResource(newStoreProvider(colFamiliesEnabled)) { provider =>
+      val store = provider.getStore(0)
+      try {
+        // Put multiple key-value pairs
+        put(store, "a", 1, 10)
+        put(store, "b", 2, 20)
+        put(store, "c", 3, 30)
+        put(store, "d", 4, 40)
+
+        // Create keys array for multiGet
+        val keys = Array(
+          dataToKeyRow("a", 1),
+          dataToKeyRow("b", 2),
+          dataToKeyRow("c", 3),
+          dataToKeyRow("nonexistent", 999) // Key that doesn't exist
+        )
+
+        // Perform multiGet
+        // Note: multiGet returns an iterator, we copy rows when collecting
+        val results = store.multiGet(keys, StateStore.DEFAULT_COL_FAMILY_NAME)
+          .map(row => if (row != null) row.copy() else null).toArray
+
+        // Verify results
+        assert(results.length === 4)
+        assert(valueRowToData(results(0)) === 10)
+        assert(valueRowToData(results(1)) === 20)
+        assert(valueRowToData(results(2)) === 30)
+        assert(results(3) === null) // Non-existent key should return null
+      } finally {
+        if (!store.hasCommitted) store.abort()
+      }
+    }
+  }
+
   testWithAllCodec(s"removing while iterating") { colFamiliesEnabled =>
     tryWithProviderResource(newStoreProvider(colFamiliesEnabled)) { provider =>
       // Verify state before starting a new set of updates

