viirya commented on a change in pull request #33038:
URL: https://github.com/apache/spark/pull/33038#discussion_r667439211
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
##########
@@ -253,6 +239,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
this.storeConf = storeConf
this.hadoopConf = hadoopConf
this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory
+
+
+    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
+      (keySchema.length > numColsPrefixKey), "The number of columns for prefix key must be " +
+      "greater than the number of columns in the key!")
+ this.numColsPrefixKey = numColsPrefixKey
Review comment:
The number of columns in the key must be greater than the number of columns for prefix key?
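
For illustration, a minimal sketch of the wording this comment points at — the condition from the diff above is kept as-is and only the message is swapped to match it (not necessarily the exact fix the author will apply):

    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
      (keySchema.length > numColsPrefixKey),
      // Message now states the actual constraint: key columns > prefix-key columns.
      "The number of columns in the key must be greater than the number of columns for prefix key!")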
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+ def supportPrefixKeyScan: Boolean
+ def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+ def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+ def encodeKey(row: UnsafeRow): Array[Byte]
+ def encodeValue(row: UnsafeRow): Array[Byte]
+
+ def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+ def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+ def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+ def getEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int): RocksDBStateEncoder = {
+ if (numColsPrefixKey > 0) {
+ new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+ } else {
+ new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+ }
+ }
+
+ /**
+ * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytesToEncode.length)
+ encodedBytes
+ }
+
+ def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+ if (bytes != null) {
+ val row = new UnsafeRow(numFields)
+ decodeToUnsafeRow(bytes, row)
+ } else {
+ null
+ }
+ }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
+ reusedRow.pointTo(
+ bytes,
+ Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+ reusedRow
+ } else {
+ null
+ }
+ }
+}
+
+class PrefixKeyScanStateEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+ import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for prefix key must be " +
+    "greater than the number of columns in the key!")
+
+ private val groupKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.take(numColsPrefixKey)
+ }
+
+ private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.drop(numColsPrefixKey)
+ }
+
+ private val prefixKeyProjection: UnsafeProjection = {
+    val refs = groupKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithIdx.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+  // This is quite simple to do - just bind sequentially, as we don't change the order.
+  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+
+ // Reusable objects
+ private val joinedRowOnKey = new JoinedRow()
+ private val valueRow = new UnsafeRow(valueSchema.size)
+ private val rowTuple = new UnsafeRowPair()
+
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+ val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate later
+ // on deserialization.
+ Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+ remainingEncoded.length)
+
+ encodedBytes
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+ override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val groupKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
+    val groupKeyEncoded = new Array[Byte](groupKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, groupKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, groupKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
+ val remainingKeyEncodedLen = keyBytes.length - 4 - groupKeyEncodedLen
+
+ val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+ Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+ groupKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+ remainingKeyEncodedLen)
+
+    val groupKeyDecoded = decodeToUnsafeRow(groupKeyEncoded, groupKeyFieldsWithIdx.length)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, numFields = 1)
+
+    restoreKeyProjection(joinedRowOnKey.withLeft(groupKeyDecoded).withRight(remainingKeyDecoded))
+ }
+
+ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+ decodeToUnsafeRow(valueBytes, valueRow)
+ }
+
+ override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+ prefixKeyProjection(key)
+ }
+
+ override def decodePrefixKey(groupKey: UnsafeRow): Array[Byte] = {
Review comment:
What this method does is more like `encodePrefixKey`?
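
For illustration, a hedged sketch of the suggested rename; the body here is an assumption that mirrors the length-header layout written by `encodeKey` above, not necessarily the PR's actual implementation:

    // Renamed from decodePrefixKey: it turns the prefix-key UnsafeRow into encoded bytes.
    def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
      val prefixKeyEncoded = encodeUnsafeRow(prefixKey)
      val prefix = new Array[Byte](prefixKeyEncoded.length + 4)
      Platform.putInt(prefix, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length)
      Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
        prefix, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
      prefix
    }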
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
##########
@@ -185,6 +189,33 @@ class RocksDB(
}
}
+ def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+ val threadId = Thread.currentThread().getId
+ val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
+ val it = writeBatch.newIteratorWithBase(db.newIterator())
+      logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
+ s"tid $tid")
Review comment:
tid -> thread id?
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+ def supportPrefixKeyScan: Boolean
+ def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+ def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+ def encodeKey(row: UnsafeRow): Array[Byte]
+ def encodeValue(row: UnsafeRow): Array[Byte]
+
+ def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+ def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+ def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+ def getEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int): RocksDBStateEncoder = {
+ if (numColsPrefixKey > 0) {
+ new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+ } else {
+ new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+ }
+ }
+
+ /**
+ * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytesToEncode.length)
+ encodedBytes
+ }
+
+ def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+ if (bytes != null) {
+ val row = new UnsafeRow(numFields)
+ decodeToUnsafeRow(bytes, row)
+ } else {
+ null
+ }
+ }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
+ reusedRow.pointTo(
+ bytes,
+ Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+ reusedRow
+ } else {
+ null
+ }
+ }
+}
+
+class PrefixKeyScanStateEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+ import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for prefix key must be " +
+ "greater than the number of columns in the key!")
+
Review comment:
The number of columns in the key must be greater than the number of
columns for prefix key?
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+ def supportPrefixKeyScan: Boolean
+ def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+ def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+ def encodeKey(row: UnsafeRow): Array[Byte]
+ def encodeValue(row: UnsafeRow): Array[Byte]
+
+ def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+ def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+ def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+ def getEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int): RocksDBStateEncoder = {
+ if (numColsPrefixKey > 0) {
+ new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+ } else {
+ new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+ }
+ }
+
+ /**
+ * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytesToEncode.length)
+ encodedBytes
+ }
+
+ def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+ if (bytes != null) {
+ val row = new UnsafeRow(numFields)
+ decodeToUnsafeRow(bytes, row)
+ } else {
+ null
+ }
+ }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
+ reusedRow.pointTo(
+ bytes,
+ Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+ reusedRow
+ } else {
+ null
+ }
+ }
+}
+
+class PrefixKeyScanStateEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+ import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for prefix key must be " +
+ "greater than the number of columns in the key!")
+
+ private val groupKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.take(numColsPrefixKey)
+ }
+
+ private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.drop(numColsPrefixKey)
+ }
+
+ private val prefixKeyProjection: UnsafeProjection = {
+    val refs = groupKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithIdx.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+  // This is quite simple to do - just bind sequentially, as we don't change the order.
+  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+
+ // Reusable objects
+ private val joinedRowOnKey = new JoinedRow()
+ private val valueRow = new UnsafeRow(valueSchema.size)
+ private val rowTuple = new UnsafeRowPair()
+
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+ val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate later
+ // on deserialization.
+ Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+ remainingEncoded.length)
+
+ encodedBytes
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+ override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val groupKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
+    val groupKeyEncoded = new Array[Byte](groupKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, groupKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, groupKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
+ val remainingKeyEncodedLen = keyBytes.length - 4 - groupKeyEncodedLen
+
+ val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+ Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+ groupKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+ remainingKeyEncodedLen)
+
+    val groupKeyDecoded = decodeToUnsafeRow(groupKeyEncoded, groupKeyFieldsWithIdx.length)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, numFields = 1)
+
+    restoreKeyProjection(joinedRowOnKey.withLeft(groupKeyDecoded).withRight(remainingKeyDecoded))
+ }
+
+ override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+ decodeToUnsafeRow(valueBytes, valueRow)
+ }
+
+ override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+ prefixKeyProjection(key)
+ }
+
+ override def decodePrefixKey(groupKey: UnsafeRow): Array[Byte] = {
+ val groupKeyEncoded = encodeUnsafeRow(groupKey)
Review comment:
Shall we name it `prefixKeyEncoded` (and rename other similar variables) to be consistent with the method/API name?
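
For illustration, the naming this comment seems to suggest for the first line of the method (a fragment only; the rest of the body is unchanged):

    // `prefixKeyEncoded` matches the prefix-key terminology used by the encoder API.
    val prefixKeyEncoded = encodeUnsafeRow(groupKey)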
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
##########
@@ -60,21 +60,17 @@ trait ReadStateStore {
def get(key: UnsafeRow): UnsafeRow
/**
- * Get key value pairs with optional approximate `start` and `end` extents.
- * If the State Store implementation maintains indices for the data based on the optional
- * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use
- * `start` and `end` to make a best-effort scan over the data. Default implementation returns
- * the full data scan iterator, which is correct but inefficient. Custom implementations must
- * ensure that updates (puts, removes) can be made while iterating over this iterator.
+ * Return an iterator containing all the key-value pairs which are matched with
+ * the given prefix key.
+ *
+ * Spark will provide numColsPrefixKey greater than 0 in StateStoreProvider.init method if
+ * the state store is responsible to handle the request for prefix scan. The schema of the
+ * prefix key should be same with the leftmost `numColsPrefixKey` columns of the key schema.
*
- * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value.
- * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value.
- * @return An iterator of key-value pairs that is guaranteed not miss any key between start and
- *         end, both inclusive.
+ * It is expected to throw exception if Spark calls this method without setting numColsPrefixKey
+ * to the greater than 0.
*/
- def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
- iterator()
- }
+ def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
Review comment:
E.g., `MemoryStateStore` doesn't support it.
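
As a hedged illustration only (not part of this PR): a store that is never initialized with numColsPrefixKey > 0, such as the test-only MemoryStateStore, could satisfy the new abstract method with a fail-fast stub along these lines:

    // Hypothetical stub for an implementation that does not support prefix scan.
    override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] =
      throw new UnsupportedOperationException("prefixScan is not supported by this state store")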
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+ def supportPrefixKeyScan: Boolean
+ def decodePrefixKey(groupKey: UnsafeRow): Array[Byte]
+ def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+ def encodeKey(row: UnsafeRow): Array[Byte]
+ def encodeValue(row: UnsafeRow): Array[Byte]
+
+ def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+ def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+ def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+ def getEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int): RocksDBStateEncoder = {
+ if (numColsPrefixKey > 0) {
+ new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+ } else {
+ new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+ }
+ }
+
+ /**
+ * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytesToEncode.length)
+ encodedBytes
+ }
+
+ def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+ if (bytes != null) {
+ val row = new UnsafeRow(numFields)
+ decodeToUnsafeRow(bytes, row)
+ } else {
+ null
+ }
+ }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
+ reusedRow.pointTo(
+ bytes,
+ Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+ bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+ reusedRow
+ } else {
+ null
+ }
+ }
+}
+
+class PrefixKeyScanStateEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+ import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns for prefix key must be " +
+ "greater than the number of columns in the key!")
+
+ private val groupKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.take(numColsPrefixKey)
+ }
+
+ private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.drop(numColsPrefixKey)
+ }
+
+ private val prefixKeyProjection: UnsafeProjection = {
+    val refs = groupKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithIdx.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+  // This is quite simple to do - just bind sequentially, as we don't change the order.
+  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+
+ // Reusable objects
+ private val joinedRowOnKey = new JoinedRow()
+ private val valueRow = new UnsafeRow(valueSchema.size)
+ private val rowTuple = new UnsafeRowPair()
+
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+ val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate later
+ // on deserialization.
+ Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+ remainingEncoded.length)
+
+ encodedBytes
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+ override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val groupKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
+    val groupKeyEncoded = new Array[Byte](groupKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, groupKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, groupKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
+ val remainingKeyEncodedLen = keyBytes.length - 4 - groupKeyEncodedLen
+
+ val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+ Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+ groupKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+ remainingKeyEncodedLen)
+
+    val groupKeyDecoded = decodeToUnsafeRow(groupKeyEncoded, groupKeyFieldsWithIdx.length)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded, numFields = 1)
Review comment:
Why is `numFields` always 1 here?
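
If the intent is to decode every remaining (non-prefix) key column, a sketch of what this may be hinting at — an assumption, using the `remainingKeyFieldsWithIdx` defined earlier in this file rather than a hard-coded 1:

    // Decode however many key columns are left after the prefix, not exactly one.
    val remainingKeyDecoded =
      decodeToUnsafeRow(remainingKeyEncoded, numFields = remainingKeyFieldsWithIdx.length)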
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
##########
@@ -147,14 +146,21 @@ private[state] class RocksDBStateStoreProvider
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
- indexOrdinal: Option[Int],
+ numColsPrefixKey: Int,
storeConf: StateStoreConf,
hadoopConf: Configuration): Unit = {
this.stateStoreId_ = stateStoreId
this.keySchema = keySchema
this.valueSchema = valueSchema
this.storeConf = storeConf
this.hadoopConf = hadoopConf
+
+ require((keySchema.length == 0 && numColsPrefixKey == 0) ||
+      (keySchema.length > numColsPrefixKey), "The number of columns for prefix key must be " +
+ "greater than the number of columns in the key!")
Review comment:
The number of columns in the key must be greater than the number of
columns for prefix key?
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
##########
@@ -60,21 +60,17 @@ trait ReadStateStore {
def get(key: UnsafeRow): UnsafeRow
/**
- * Get key value pairs with optional approximate `start` and `end` extents.
- * If the State Store implementation maintains indices for the data based on the optional
- * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use
- * `start` and `end` to make a best-effort scan over the data. Default implementation returns
- * the full data scan iterator, which is correct but inefficient. Custom implementations must
- * ensure that updates (puts, removes) can be made while iterating over this iterator.
+ * Return an iterator containing all the key-value pairs which are matched with
+ * the given prefix key.
+ *
+ * Spark will provide numColsPrefixKey greater than 0 in StateStoreProvider.init method if
+ * the state store is responsible to handle the request for prefix scan. The schema of the
+ * prefix key should be same with the leftmost `numColsPrefixKey` columns of the key schema.
*
- * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value.
- * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value.
- * @return An iterator of key-value pairs that is guaranteed not miss any key between start and
- *         end, both inclusive.
+ * It is expected to throw exception if Spark calls this method without setting numColsPrefixKey
+ * to the greater than 0.
*/
- def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
- iterator()
- }
+ def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
Review comment:
Is it necessary for third-party StateStore implementations to support this? If an implementation doesn't support it, what would happen? Could you clarify it in the doc?
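
For illustration, hedged wording that could be folded into the `prefixScan` scaladoc to answer this (an assumption about phrasing, not the author's text; the first two lines are from the existing doc in the diff above):

    /**
     * Return an iterator containing all the key-value pairs which are matched with
     * the given prefix key.
     *
     * Implementations which are never initialized with numColsPrefixKey > 0 are not required
     * to support this; they may simply throw UnsupportedOperationException.
     */
    def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]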
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
##########
@@ -251,16 +250,15 @@ trait StateStoreProvider {
* @param stateStoreId Id of the versioned StateStores that this provider will generate
* @param keySchema Schema of keys to be stored
* @param valueSchema Schema of value to be stored
- * @param keyIndexOrdinal Optional column (represent as the ordinal of the field in keySchema) by
- *                        which the StateStore implementation could index the data.
+ * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
Review comment:
"A value not greater than 0 means that the StateStore doesn't support
`prefixScan` API."
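
For illustration, folding that sentence into the param doc might read as follows (a sketch of the suggested wording only):

    * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
    *                         A value not greater than 0 means that the StateStore doesn't
    *                         support the `prefixScan` API.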
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]