viirya commented on a change in pull request #33038:
URL: https://github.com/apache/spark/pull/33038#discussion_r667439211



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
##########
@@ -253,6 +239,12 @@ private[state] class HDFSBackedStateStoreProvider extends 
StateStoreProvider wit
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
     this.numberOfVersionsToRetainInMemory = 
storeConf.maxVersionsToRetainInMemory
+
+    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
+      (keySchema.length > numColsPrefixKey), "The number of columns for prefix 
key must be " +
+      "greater than the number of columns in the key!")
+    this.numColsPrefixKey = numColsPrefixKey

Review comment:
       Shouldn't the message be the other way around: the number of columns in the key must be greater than the number of columns for the prefix key?
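
For reference, a minimal sketch of the same check with the message flipped to match the condition (wording is only a suggestion, not part of the PR):

```scala
// Sketch: the check is `keySchema.length > numColsPrefixKey`, so the message
// should state that the key has more columns than the prefix key.
require((keySchema.length == 0 && numColsPrefixKey == 0) ||
  (keySchema.length > numColsPrefixKey),
  "The number of columns in the key must be greater than the number of columns for the prefix key!")
```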

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, 
UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES,
 STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+  def supportPrefixKeyScan: Boolean
+  def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+  def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+  def encodeKey(row: UnsafeRow): Array[Byte]
+  def encodeValue(row: UnsafeRow): Array[Byte]
+
+  def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+  def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+  def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+  def getEncoder(
+      keySchema: StructType,
+      valueSchema: StructType,
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
+    if (numColsPrefixKey > 0) {
+      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+    } else {
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+    }
+  }
+
+  /**
+   * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the 
new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + 
STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte 
arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 
STATE_ENCODING_NUM_VERSION_BYTES,
+      bytesToEncode.length)
+    encodedBytes
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+    if (bytes != null) {
+      val row = new UnsafeRow(numFields)
+      decodeToUnsafeRow(bytes, row)
+    } else {
+      null
+    }
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = 
{
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st 
offset. See Platform.
+      reusedRow.pointTo(
+        bytes,
+        Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+      reusedRow
+    } else {
+      null
+    }
+  }
+}
+
+class PrefixKeyScanStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for 
prefix key must be " +
+    "greater than the number of columns in the key!")
+
+  private val groupKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numColsPrefixKey)
+  }
+
+  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.drop(numColsPrefixKey)
+  }
+
+  private val prefixKeyProjection: UnsafeProjection = {
+    val refs = groupKeyFieldsWithIdx.map(x => BoundReference(x._2, 
x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val remainingKeyProjection: UnsafeProjection = {
+    val refs = remainingKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  // This is quite simple to do - just bind sequentially, as we don't change 
the order.
+  private val restoreKeyProjection: UnsafeProjection = 
UnsafeProjection.create(keySchema)
+
+  // Reusable objects
+  private val joinedRowOnKey = new JoinedRow()
+  private val valueRow = new UnsafeRow(valueSchema.size)
+  private val rowTuple = new UnsafeRowPair()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+    val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + 
remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate 
later
+    // on deserialization.
+    Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+      remainingEncoded.length)
+
+    encodedBytes
+  }
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val groupKeyEncodedLen = Platform.getInt(keyBytes, 
Platform.BYTE_ARRAY_OFFSET)
+    val groupKeyEncoded = new Array[Byte](groupKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, 
groupKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, groupKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of 
keyBytes
+    val remainingKeyEncodedLen = keyBytes.length - 4 - groupKeyEncodedLen
+
+    val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+      groupKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      remainingKeyEncodedLen)
+
+    val groupKeyDecoded = decodeToUnsafeRow(groupKeyEncoded, 
groupKeyFieldsWithIdx.length)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, numFields 
= 1)
+
+    
restoreKeyProjection(joinedRowOnKey.withLeft(groupKeyDecoded).withRight(remainingKeyDecoded))
+  }
+
+  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+    decodeToUnsafeRow(valueBytes, valueRow)
+  }
+
+  override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+    prefixKeyProjection(key)
+  }
+
+  override def decodePrefixKey(groupKey: UnsafeRow): Array[Byte] = {

Review comment:
       Isn't what this method does more like `encodePrefixKey`?
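
For illustration, the signature under the suggested name might read as follows (a rename-only sketch, behavior unchanged):

```scala
// Sketch: the method maps a prefix-key UnsafeRow to its byte encoding,
// so "encode" describes it more accurately than "decode".
def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
```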

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
##########
@@ -185,6 +189,33 @@ class RocksDB(
     }
   }
 
+  def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+    val threadId = Thread.currentThread().getId
+    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
+      val it = writeBatch.newIteratorWithBase(db.newIterator())
+      logInfo(s"Getting iterator from version $loadedVersion for prefix scan 
on " +
+        s"tid $tid")

Review comment:
       `tid` -> `thread id` in the log message?
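
For example, the log line might spell it out (a minor wording sketch only):

```scala
// Sketch: use "thread id" instead of the abbreviation "tid" in the message.
logInfo(s"Getting iterator from version $loadedVersion for prefix scan " +
  s"on thread id $tid")
```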

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, 
UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES,
 STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+  def supportPrefixKeyScan: Boolean
+  def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+  def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+  def encodeKey(row: UnsafeRow): Array[Byte]
+  def encodeValue(row: UnsafeRow): Array[Byte]
+
+  def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+  def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+  def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+  def getEncoder(
+      keySchema: StructType,
+      valueSchema: StructType,
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
+    if (numColsPrefixKey > 0) {
+      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+    } else {
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+    }
+  }
+
+  /**
+   * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the 
new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + 
STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte 
arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 
STATE_ENCODING_NUM_VERSION_BYTES,
+      bytesToEncode.length)
+    encodedBytes
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+    if (bytes != null) {
+      val row = new UnsafeRow(numFields)
+      decodeToUnsafeRow(bytes, row)
+    } else {
+      null
+    }
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = 
{
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st 
offset. See Platform.
+      reusedRow.pointTo(
+        bytes,
+        Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+      reusedRow
+    } else {
+      null
+    }
+  }
+}
+
+class PrefixKeyScanStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for 
prefix key must be " +
+    "greater than the number of columns in the key!")
+

Review comment:
       Shouldn't the message be the other way around: the number of columns in the key must be greater than the number of columns for the prefix key?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, 
UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES,
 STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+  def supportPrefixKeyScan: Boolean
+  def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+  def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+  def encodeKey(row: UnsafeRow): Array[Byte]
+  def encodeValue(row: UnsafeRow): Array[Byte]
+
+  def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+  def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+  def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+  def getEncoder(
+      keySchema: StructType,
+      valueSchema: StructType,
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
+    if (numColsPrefixKey > 0) {
+      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+    } else {
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+    }
+  }
+
+  /**
+   * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the 
new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + 
STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte 
arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 
STATE_ENCODING_NUM_VERSION_BYTES,
+      bytesToEncode.length)
+    encodedBytes
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+    if (bytes != null) {
+      val row = new UnsafeRow(numFields)
+      decodeToUnsafeRow(bytes, row)
+    } else {
+      null
+    }
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = 
{
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st 
offset. See Platform.
+      reusedRow.pointTo(
+        bytes,
+        Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+      reusedRow
+    } else {
+      null
+    }
+  }
+}
+
+class PrefixKeyScanStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for 
prefix key must be " +
+    "greater than the number of columns in the key!")
+
+  private val groupKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numColsPrefixKey)
+  }
+
+  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.drop(numColsPrefixKey)
+  }
+
+  private val prefixKeyProjection: UnsafeProjection = {
+    val refs = groupKeyFieldsWithIdx.map(x => BoundReference(x._2, 
x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val remainingKeyProjection: UnsafeProjection = {
+    val refs = remainingKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  // This is quite simple to do - just bind sequentially, as we don't change 
the order.
+  private val restoreKeyProjection: UnsafeProjection = 
UnsafeProjection.create(keySchema)
+
+  // Reusable objects
+  private val joinedRowOnKey = new JoinedRow()
+  private val valueRow = new UnsafeRow(valueSchema.size)
+  private val rowTuple = new UnsafeRowPair()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+    val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + 
remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate 
later
+    // on deserialization.
+    Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+      remainingEncoded.length)
+
+    encodedBytes
+  }
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val groupKeyEncodedLen = Platform.getInt(keyBytes, 
Platform.BYTE_ARRAY_OFFSET)
+    val groupKeyEncoded = new Array[Byte](groupKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, 
groupKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, groupKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of 
keyBytes
+    val remainingKeyEncodedLen = keyBytes.length - 4 - groupKeyEncodedLen
+
+    val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+      groupKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      remainingKeyEncodedLen)
+
+    val groupKeyDecoded = decodeToUnsafeRow(groupKeyEncoded, 
groupKeyFieldsWithIdx.length)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, numFields 
= 1)
+
+    
restoreKeyProjection(joinedRowOnKey.withLeft(groupKeyDecoded).withRight(remainingKeyDecoded))
+  }
+
+  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+    decodeToUnsafeRow(valueBytes, valueRow)
+  }
+
+  override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+    prefixKeyProjection(key)
+  }
+
+  override def decodePrefixKey(groupKey: UnsafeRow): Array[Byte] = {
+    val groupKeyEncoded = encodeUnsafeRow(groupKey)

Review comment:
       Shall we name this `prefixKeyEncoded` (and rename other similar variables accordingly) to be consistent with the method/API names?
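
For illustration, the naming applied inside `decodeKey` could look like this (a sketch; behavior unchanged):

```scala
// Sketch: local variables follow the "prefix key" terminology of the API.
val prefixKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded,
  Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
```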

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
##########
@@ -60,21 +60,17 @@ trait ReadStateStore {
   def get(key: UnsafeRow): UnsafeRow
 
   /**
-   * Get key value pairs with optional approximate `start` and `end` extents.
-   * If the State Store implementation maintains indices for the data based on 
the optional
-   * `keyIndexOrdinal` over fields `keySchema` (see 
`StateStoreProvider.init()`), then it can use
-   * `start` and `end` to make a best-effort scan over the data. Default 
implementation returns
-   * the full data scan iterator, which is correct but inefficient. Custom 
implementations must
-   * ensure that updates (puts, removes) can be made while iterating over this 
iterator.
+   * Return an iterator containing all the key-value pairs which are matched 
with
+   * the given prefix key.
+   *
+   * Spark will provide numColsPrefixKey greater than 0 in 
StateStoreProvider.init method if
+   * the state store is responsible to handle the request for prefix scan. The 
schema of the
+   * prefix key should be same with the leftmost `numColsPrefixKey` columns of 
the key schema.
    *
-   * @param start UnsafeRow having the `keyIndexOrdinal` column set with 
appropriate starting value.
-   * @param end UnsafeRow having the `keyIndexOrdinal` column set with 
appropriate ending value.
-   * @return An iterator of key-value pairs that is guaranteed not miss any 
key between start and
-   *         end, both inclusive.
+   * It is expected to throw exception if Spark calls this method without 
setting numColsPrefixKey
+   * to the greater than 0.
    */
-  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): 
Iterator[UnsafeRowPair] = {
-    iterator()
-  }
+  def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]

Review comment:
       E.g., `MemoryStateStore` doesn't support it.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, 
UnsafeProjection, UnsafeRow}
+import 
org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES,
 STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+  def supportPrefixKeyScan: Boolean
+  def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+  def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+  def encodeKey(row: UnsafeRow): Array[Byte]
+  def encodeValue(row: UnsafeRow): Array[Byte]
+
+  def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+  def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+  def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+  def getEncoder(
+      keySchema: StructType,
+      valueSchema: StructType,
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
+    if (numColsPrefixKey > 0) {
+      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+    } else {
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+    }
+  }
+
+  /**
+   * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the 
new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + 
STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte 
arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 
STATE_ENCODING_NUM_VERSION_BYTES,
+      bytesToEncode.length)
+    encodedBytes
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+    if (bytes != null) {
+      val row = new UnsafeRow(numFields)
+      decodeToUnsafeRow(bytes, row)
+    } else {
+      null
+    }
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = 
{
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st 
offset. See Platform.
+      reusedRow.pointTo(
+        bytes,
+        Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+      reusedRow
+    } else {
+      null
+    }
+  }
+}
+
+class PrefixKeyScanStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for 
prefix key must be " +
+    "greater than the number of columns in the key!")
+
+  private val groupKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numColsPrefixKey)
+  }
+
+  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.drop(numColsPrefixKey)
+  }
+
+  private val prefixKeyProjection: UnsafeProjection = {
+    val refs = groupKeyFieldsWithIdx.map(x => BoundReference(x._2, 
x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val remainingKeyProjection: UnsafeProjection = {
+    val refs = remainingKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  // This is quite simple to do - just bind sequentially, as we don't change 
the order.
+  private val restoreKeyProjection: UnsafeProjection = 
UnsafeProjection.create(keySchema)
+
+  // Reusable objects
+  private val joinedRowOnKey = new JoinedRow()
+  private val valueRow = new UnsafeRow(valueSchema.size)
+  private val rowTuple = new UnsafeRowPair()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+    val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + 
remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, 
prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate 
later
+    // on deserialization.
+    Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+      remainingEncoded.length)
+
+    encodedBytes
+  }
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val groupKeyEncodedLen = Platform.getInt(keyBytes, 
Platform.BYTE_ARRAY_OFFSET)
+    val groupKeyEncoded = new Array[Byte](groupKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, 
groupKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, groupKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of 
keyBytes
+    val remainingKeyEncodedLen = keyBytes.length - 4 - groupKeyEncodedLen
+
+    val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+      groupKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      remainingKeyEncodedLen)
+
+    val groupKeyDecoded = decodeToUnsafeRow(groupKeyEncoded, 
groupKeyFieldsWithIdx.length)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, numFields 
= 1)

Review comment:
       Why is `numFields` always 1 here?
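
If the intent is to keep every remaining (non-prefix) key column, the call would presumably be sized by the remaining field count rather than hard-coded to 1 (a sketch under that assumption):

```scala
// Sketch: size the reconstructed row by the number of non-prefix key columns.
val remainingKeyDecoded =
  decodeToUnsafeRow(remainingKeyEncoded, numFields = remainingKeyFieldsWithIdx.length)
```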

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
##########
@@ -147,14 +146,21 @@ private[state] class RocksDBStateStoreProvider
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int],
+      numColsPrefixKey: Int,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
     this.keySchema = keySchema
     this.valueSchema = valueSchema
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
+
+    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
+      (keySchema.length > numColsPrefixKey), "The number of columns for prefix 
key must be " +
+      "greater than the number of columns in the key!")

Review comment:
       Shouldn't the message be the other way around: the number of columns in the key must be greater than the number of columns for the prefix key?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
##########
@@ -60,21 +60,17 @@ trait ReadStateStore {
   def get(key: UnsafeRow): UnsafeRow
 
   /**
-   * Get key value pairs with optional approximate `start` and `end` extents.
-   * If the State Store implementation maintains indices for the data based on 
the optional
-   * `keyIndexOrdinal` over fields `keySchema` (see 
`StateStoreProvider.init()`), then it can use
-   * `start` and `end` to make a best-effort scan over the data. Default 
implementation returns
-   * the full data scan iterator, which is correct but inefficient. Custom 
implementations must
-   * ensure that updates (puts, removes) can be made while iterating over this 
iterator.
+   * Return an iterator containing all the key-value pairs which are matched 
with
+   * the given prefix key.
+   *
+   * Spark will provide numColsPrefixKey greater than 0 in 
StateStoreProvider.init method if
+   * the state store is responsible to handle the request for prefix scan. The 
schema of the
+   * prefix key should be same with the leftmost `numColsPrefixKey` columns of 
the key schema.
    *
-   * @param start UnsafeRow having the `keyIndexOrdinal` column set with 
appropriate starting value.
-   * @param end UnsafeRow having the `keyIndexOrdinal` column set with 
appropriate ending value.
-   * @return An iterator of key-value pairs that is guaranteed not miss any 
key between start and
-   *         end, both inclusive.
+   * It is expected to throw exception if Spark calls this method without 
setting numColsPrefixKey
+   * to the greater than 0.
    */
-  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): 
Iterator[UnsafeRowPair] = {
-    iterator()
-  }
+  def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]

Review comment:
       Is it necessary for third-party StateStore implementations to support this? If an implementation doesn't support it, what would happen? Could you clarify this in the doc?
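
One way an implementation without prefix-scan support could satisfy the abstract method, assuming the contract stays as documented above (a sketch, not part of the PR):

```scala
// Sketch: per the doc above, Spark only calls prefixScan after passing
// numColsPrefixKey > 0 to init, so a store that cannot support it can fail fast.
override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
  throw new UnsupportedOperationException(
    s"${getClass.getSimpleName} doesn't support prefix scan")
}
```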

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
##########
@@ -251,16 +250,15 @@ trait StateStoreProvider {
    * @param stateStoreId Id of the versioned StateStores that this provider 
will generate
    * @param keySchema Schema of keys to be stored
    * @param valueSchema Schema of value to be stored
-   * @param keyIndexOrdinal Optional column (represent as the ordinal of the 
field in keySchema) by
-   *                        which the StateStore implementation could index 
the data.
+   * @param numColsPrefixKey The number of leftmost columns to be used as 
prefix key.

Review comment:
       "A value not greater than 0 means that the StateStore doesn't support 
`prefixScan` API."
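
Incorporating that suggestion, the param doc might read roughly as follows (suggested wording only):

```scala
/**
 * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
 *                         A value not greater than 0 means that the StateStore
 *                         doesn't support the `prefixScan` API.
 */
```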




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
