This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 229118ca7a12 [SPARK-50599][SQL] Create the DataEncoder trait that allows for Avro and UnsafeRow encoding
229118ca7a12 is described below
commit 229118ca7a127753635543909efdb27601985d42
Author: Eric Marnadi <[email protected]>
AuthorDate: Wed Dec 18 16:06:45 2024 +0900
    [SPARK-50599][SQL] Create the DataEncoder trait that allows for Avro and UnsafeRow encoding
### What changes were proposed in this pull request?
    Currently, we use the internal byte representation to store state for stateful streaming operators in the StateStore. This PR introduces Avro serialization and deserialization capabilities in the RocksDBStateEncoder so that we can instead use Avro encoding to store state. This is currently enabled for the TransformWithState operator via SQLConf, and it supports all functionality supported by TWS.
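
    At its core, the Avro path turns each `UnsafeRow` into an Avro `GenericRecord` (via `AvroSerializer`) and writes it with Avro's binary encoder, instead of persisting the row's raw bytes. Below is a minimal, hypothetical sketch of that byte-level round trip using the same Avro APIs the new encoder relies on (`EncoderFactory`, `DecoderFactory`, `GenericDatumWriter`/`GenericDatumReader`); the one-field `schema` here is only a stand-in for the schema that `SchemaConverters.toAvroType` would derive from the actual state `StructType`:

    ```scala
    import java.io.ByteArrayOutputStream

    import org.apache.avro.Schema
    import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
    import org.apache.avro.io.{DecoderFactory, EncoderFactory}

    object AvroStateRoundTrip {
      // Hypothetical one-column schema standing in for SchemaConverters.toAvroType(valueSchema).
      val schema: Schema = new Schema.Parser().parse(
        """{"type":"record","name":"Value","fields":[{"name":"count","type":"long"}]}""")

      // GenericRecord -> raw bytes, the form in which state is stored in RocksDB.
      def encode(record: GenericRecord, out: ByteArrayOutputStream): Array[Byte] = {
        out.reset()
        val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
        val writer = new GenericDatumWriter[GenericRecord](schema)
        writer.write(record, encoder)
        encoder.flush()
        out.toByteArray
      }

      // raw bytes -> GenericRecord; the real encoder then projects this back to an UnsafeRow.
      def decode(bytes: Array[Byte]): GenericRecord = {
        val reader = new GenericDatumReader[GenericRecord](schema)
        val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
        reader.read(null, decoder)
      }

      def main(args: Array[String]): Unit = {
        val rec = new GenericData.Record(schema)
        rec.put("count", 42L)
        val bytes = encode(rec, new ByteArrayOutputStream)
        assert(decode(bytes).get("count") == 42L)
      }
    }
    ```

    In the actual change, `AvroSerializer` and `AvroDeserializer` bridge between `UnsafeRow` and `GenericRecord`, so state store callers never handle Avro records directly.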
### Why are the changes needed?
    UnsafeRow is an inherently unstable format that makes no backwards-compatibility guarantees, so a format change between Spark releases could corrupt the StateStore. Avro is more stable, and it inherently enables schema evolution.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
    Amended existing unit tests and added new ones.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48944 from ericm-db/avro-ss.
Lead-authored-by: Eric Marnadi <[email protected]>
Co-authored-by: Eric Marnadi <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../streaming/state/RocksDBStateEncoder.scala | 1711 +++++++++++---------
.../state/RocksDBStateStoreProvider.scala | 169 +-
.../sql/execution/streaming/state/StateStore.scala | 41 +
3 files changed, 1080 insertions(+), 841 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index f39022c1f53a..b4f619781193 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -27,7 +27,7 @@ import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWri
 import org.apache.avro.io.{DecoderFactory, EncoderFactory}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.avro.{AvroDeserializer, AvroSerializer, SchemaConverters}
+import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
@@ -51,93 +51,138 @@ sealed trait RocksDBValueStateEncoder {
def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
}
-abstract class RocksDBKeyStateEncoderBase(
- useColumnFamilies: Boolean,
- virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
- def offsetForColFamilyPrefix: Int =
- if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+/**
+ * The DataEncoder can encode UnsafeRows into raw bytes in two ways:
+ * - Using the direct byte layout of the UnsafeRow
+ * - Converting the UnsafeRow into an Avro row, and encoding that
+ * In both of these cases, the raw bytes that are written into RocksDB have
+ * headers, footers and other metadata, but they also have data that is provided
+ * by the callers. The metadata in each row does not need to be written as Avro or UnsafeRow,
+ * but the actual data provided by the caller does.
+ * The classes that use this trait require specialized partial encoding, which makes them much
+ * easier to cache and use, which is why each DataEncoder deals with multiple schemas.
+ */
+trait DataEncoder {
+  /**
+   * Encodes a complete key row into bytes. Used as the primary key for state lookups.
+   *
+   * @param row An UnsafeRow containing all key columns as defined in the keySchema
+   * @return Serialized byte array representation of the key
+   */
+  def encodeKey(row: UnsafeRow): Array[Byte]
-  val out = new ByteArrayOutputStream
   /**
-   * Get Byte Array for the virtual column family id that is used as prefix for
-   * key state rows.
+   * Encodes the non-prefix portion of a key row. Used with prefix scan and
+   * range scan state lookups where the key is split into prefix and remaining portions.
+   *
+   * For prefix scans: Encodes columns after the prefix columns
+   * For range scans: Encodes columns not included in the ordering columns
+   *
+   * @param row An UnsafeRow containing only the remaining key columns
+   * @return Serialized byte array of the remaining key portion
+   * @throws UnsupportedOperationException if called on an encoder that doesn't support split keys
    */
-  override def getColumnFamilyIdBytes(): Array[Byte] = {
-    assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
-      " because multiple Column is not supported for this encoder")
-    val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
-    Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId.get)
-    encodedBytes
-  }
+  def encodeRemainingKey(row: UnsafeRow): Array[Byte]
   /**
-   * Encode and put column family Id as a prefix to a pre-allocated byte array.
+   * Encodes key columns used for range scanning, ensuring proper sort order in RocksDB.
    *
-   * @param numBytes - size of byte array to be created for storing key row (without
-   *                    column family prefix)
-   * @return Array[Byte] for an array byte to put encoded key bytes
-   *         Int for a starting offset to put the encoded key bytes
+   * This method handles special encoding for numeric types to maintain correct sort order:
+   * - Adds sign byte markers for numeric types
+   * - Flips bits for negative floating point values
+   * - Preserves null ordering
+   *
+   * @param row An UnsafeRow containing the columns needed for range scan
+   *            (specified by orderingOrdinals)
+   * @return Serialized bytes that will maintain correct sort order in RocksDB
+   * @throws UnsupportedOperationException if called on an encoder that doesn't support range scans
    */
-  protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
-    val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
-    var offset = Platform.BYTE_ARRAY_OFFSET
-    if (useColumnFamilies) {
-      Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId.get)
-      offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
-    }
-    (encodedBytes, offset)
-  }
+  def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte]
   /**
-   * Get starting offset for decoding an encoded key byte array.
+   * Encodes a value row into bytes.
+   *
+   * @param row An UnsafeRow containing the value columns as defined in the valueSchema
+   * @return Serialized byte array representation of the value
    */
-  protected def decodeKeyStartOffset: Int = {
-    if (useColumnFamilies) {
-      Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
-    } else Platform.BYTE_ARRAY_OFFSET
-  }
+  def encodeValue(row: UnsafeRow): Array[Byte]
+
+  /**
+   * Decodes a complete key from its serialized byte form.
+   *
+   * For NoPrefixKeyStateEncoder: Decodes the entire key
+   * For PrefixKeyScanStateEncoder: Decodes only the prefix portion
+   *
+   * @param bytes Serialized byte array containing the encoded key
+   * @return UnsafeRow containing the decoded key columns
+   * @throws UnsupportedOperationException for unsupported encoder types
+   */
+  def decodeKey(bytes: Array[Byte]): UnsafeRow
+
+  /**
+   * Decodes the remaining portion of a split key from its serialized form.
+   *
+   * For PrefixKeyScanStateEncoder: Decodes columns after the prefix
+   * For RangeKeyScanStateEncoder: Decodes non-ordering columns
+   *
+   * @param bytes Serialized byte array containing the encoded remaining key portion
+   * @return UnsafeRow containing the decoded remaining key columns
+   * @throws UnsupportedOperationException if called on an encoder that doesn't support split keys
+   */
+  def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow
+
+  /**
+   * Decodes range scan key bytes back into an UnsafeRow, preserving proper ordering.
+   *
+   * This method reverses the special encoding done by encodePrefixKeyForRangeScan:
+   * - Interprets sign byte markers
+   * - Reverses bit flipping for negative floating point values
+   * - Handles null values
+   *
+   * @param bytes Serialized byte array containing the encoded range scan key
+   * @return UnsafeRow containing the decoded range scan columns
+   * @throws UnsupportedOperationException if called on an encoder that doesn't support range scans
+   */
+  def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow
+
+  /**
+   * Decodes a value from its serialized byte form.
+   *
+   * @param bytes Serialized byte array containing the encoded value
+   * @return UnsafeRow containing the decoded value columns
+   */
+  def decodeValue(bytes: Array[Byte]): UnsafeRow
}
-object RocksDBStateEncoder extends Logging {
-  def getKeyEncoder(
-      keyStateEncoderSpec: KeyStateEncoderSpec,
-      useColumnFamilies: Boolean,
-      virtualColFamilyId: Option[Short] = None,
-      avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
-    // Return the key state encoder based on the requested type
-    keyStateEncoderSpec match {
-      case NoPrefixKeyStateEncoderSpec(keySchema) =>
-        new NoPrefixKeyStateEncoder(keySchema, useColumnFamilies, virtualColFamilyId, avroEnc)
+abstract class RocksDBDataEncoder(
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    valueSchema: StructType) extends DataEncoder {
-      case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
-        new PrefixKeyScanStateEncoder(keySchema, numColsPrefixKey,
-          useColumnFamilies, virtualColFamilyId, avroEnc)
+  val keySchema = keyStateEncoderSpec.keySchema
+  val reusedKeyRow = new UnsafeRow(keyStateEncoderSpec.keySchema.length)
+  val reusedValueRow = new UnsafeRow(valueSchema.length)
-      case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
-        new RangeKeyScanStateEncoder(keySchema, orderingOrdinals,
-          useColumnFamilies, virtualColFamilyId, avroEnc)
+  // bit masks used for checking sign or flipping all bits for negative float/double values
+  val floatFlipBitMask = 0xFFFFFFFF
+  val floatSignBitMask = 0x80000000
-      case _ =>
-        throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
-          s"$keyStateEncoderSpec")
-    }
-  }
+  val doubleFlipBitMask = 0xFFFFFFFFFFFFFFFFL
+  val doubleSignBitMask = 0x8000000000000000L
-  def getValueEncoder(
-      valueSchema: StructType,
-      useMultipleValuesPerKey: Boolean,
-      avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
-    if (useMultipleValuesPerKey) {
-      new MultiValuedStateEncoder(valueSchema, avroEnc)
-    } else {
-      new SingleValueStateEncoder(valueSchema, avroEnc)
-    }
-  }
+  // Byte markers used to identify whether the value is null, negative or positive
+  // To ensure sorted ordering, we use the lowest byte value for negative numbers followed by
+  // positive numbers and then null values.
+  val negativeValMarker: Byte = 0x00.toByte
+  val positiveValMarker: Byte = 0x01.toByte
+  val nullValMarker: Byte = 0x02.toByte
-  def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = {
-    val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
-    Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId)
-    encodedBytes
+
+  def unsupportedOperationForKeyStateEncoder(
+      operation: String
+  ): UnsupportedOperationException = {
+    new UnsupportedOperationException(
+      s"Method $operation not supported for encoder spec type " +
+      s"${keyStateEncoderSpec.getClass.getSimpleName}")
   }
/**
@@ -156,26 +201,6 @@ object RocksDBStateEncoder extends Logging {
encodedBytes
}
- /**
-   * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding.
- */
- def encodeUnsafeRowToAvro(
- row: UnsafeRow,
- avroSerializer: AvroSerializer,
- valueAvroType: Schema,
- out: ByteArrayOutputStream): Array[Byte] = {
- // InternalRow -> Avro.GenericDataRecord
- val avroData =
- avroSerializer.serialize(row)
- out.reset()
- val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
- val writer = new GenericDatumWriter[Any](
- valueAvroType) // Defining Avro writer for this struct type
- writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array
- encoder.flush()
- out.toByteArray
- }
-
def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
if (bytes != null) {
val row = new UnsafeRow(numFields)
@@ -185,26 +210,6 @@ object RocksDBStateEncoder extends Logging {
}
}
- /**
- * This method takes a byte array written using Avro encoding, and
- * deserializes to an UnsafeRow using the Avro deserializer
- */
- def decodeFromAvroToUnsafeRow(
- valueBytes: Array[Byte],
- avroDeserializer: AvroDeserializer,
- valueAvroType: Schema,
- valueProj: UnsafeProjection): UnsafeRow = {
- val reader = new GenericDatumReader[Any](valueAvroType)
-    val decoder = DecoderFactory.get().binaryDecoder(valueBytes, 0, valueBytes.length, null)
- // bytes -> Avro.GenericDataRecord
- val genericData = reader.read(null, decoder)
- // Avro.GenericDataRecord -> InternalRow
- val internalRow = avroDeserializer.deserialize(
- genericData).orNull.asInstanceOf[InternalRow]
- // InternalRow -> UnsafeRow
- valueProj.apply(internalRow)
- }
-
   def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
     if (bytes != null) {
       // Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the 1st offset. See Platform.
@@ -219,470 +224,403 @@ object RocksDBStateEncoder extends Logging {
}
}
-/**
- * RocksDB Key Encoder for UnsafeRow that supports prefix scan
- *
- * @param keySchema - schema of the key to be encoded
- * @param numColsPrefixKey - number of columns to be used for prefix key
- * @param useColumnFamilies - if column family is enabled for this encoder
- * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
- *                  be defined
- */
-class PrefixKeyScanStateEncoder(
-    keySchema: StructType,
-    numColsPrefixKey: Int,
-    useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None,
-    avroEnc: Option[AvroEncoder] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
-
-  import RocksDBStateEncoder._
+class UnsafeRowDataEncoder(
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) {
- private val usingAvroEncoding = avroEnc.isDefined
- private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
- keySchema.zipWithIndex.take(numColsPrefixKey)
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ encodeUnsafeRow(row)
}
- private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
- keySchema.zipWithIndex.drop(numColsPrefixKey)
+ override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = {
+ encodeUnsafeRow(row)
}
- private val prefixKeyProjection: UnsafeProjection = {
-    val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
- UnsafeProjection.create(refs)
- }
+  override def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] = {
+    assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec])
+    val rsk = keyStateEncoderSpec.asInstanceOf[RangeKeyScanStateEncoderSpec]
+    val rangeScanKeyFieldsWithOrdinal = rsk.orderingOrdinals.map { ordinal =>
+      val field = rsk.keySchema(ordinal)
+      (field, ordinal)
+    }
+    val writer = new UnsafeRowWriter(rsk.orderingOrdinals.length)
+    writer.resetRowWriter()
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+      val value = row.get(idx, field.dataType)
+      // Note that we cannot allocate a smaller buffer here even if the value is null
+      // because the effective byte array is considered variable size and needs to have
+      // the same size across all rows for the ordering to work as expected.
+      val bbuf = ByteBuffer.allocate(field.dataType.defaultSize + 1)
+      bbuf.order(ByteOrder.BIG_ENDIAN)
+      if (value == null) {
+        bbuf.put(nullValMarker)
+        writer.write(idx, bbuf.array())
+      } else {
+        field.dataType match {
+          case BooleanType =>
+          case ByteType =>
+            val byteVal = value.asInstanceOf[Byte]
+            val signCol = if (byteVal < 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            }
+            bbuf.put(signCol)
+            bbuf.put(byteVal)
+            writer.write(idx, bbuf.array())
- private val remainingKeyProjection: UnsafeProjection = {
- val refs = remainingKeyFieldsWithIdx.map(x =>
- BoundReference(x._2, x._1.dataType, x._1.nullable))
- UnsafeProjection.create(refs)
- }
+ case ShortType =>
+ val shortVal = value.asInstanceOf[Short]
+ val signCol = if (shortVal < 0) {
+ negativeValMarker
+ } else {
+ positiveValMarker
+ }
+ bbuf.put(signCol)
+ bbuf.putShort(shortVal)
+ writer.write(idx, bbuf.array())
- // Prefix Key schema and projection definitions used by the Avro Serializers
- // and Deserializers
- private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
-  private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
- private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+ case IntegerType =>
+ val intVal = value.asInstanceOf[Int]
+ val signCol = if (intVal < 0) {
+ negativeValMarker
+ } else {
+ positiveValMarker
+ }
+ bbuf.put(signCol)
+ bbuf.putInt(intVal)
+ writer.write(idx, bbuf.array())
-  // Remaining Key schema and projection definitions used by the Avro Serializers
-  // and Deserializers
-  private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
-  private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
- private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+ case LongType =>
+ val longVal = value.asInstanceOf[Long]
+ val signCol = if (longVal < 0) {
+ negativeValMarker
+ } else {
+ positiveValMarker
+ }
+ bbuf.put(signCol)
+ bbuf.putLong(longVal)
+ writer.write(idx, bbuf.array())
-  // This is quite simple to do - just bind sequentially, as we don't change the order.
-  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+          case FloatType =>
+            val floatVal = value.asInstanceOf[Float]
+            val rawBits = floatToRawIntBits(floatVal)
+            // perform sign comparison using bit manipulation to ensure NaN values are handled
+            // correctly
+            if ((rawBits & floatSignBitMask) != 0) {
+              // for negative values, we need to flip all the bits to ensure correct ordering
+              val updatedVal = rawBits ^ floatFlipBitMask
+              bbuf.put(negativeValMarker)
+              // convert the bits back to float
+              bbuf.putFloat(intBitsToFloat(updatedVal))
+            } else {
+              bbuf.put(positiveValMarker)
+              bbuf.putFloat(floatVal)
+            }
+            writer.write(idx, bbuf.array())
- // Reusable objects
- private val joinedRowOnKey = new JoinedRow()
+          case DoubleType =>
+            val doubleVal = value.asInstanceOf[Double]
+            val rawBits = doubleToRawLongBits(doubleVal)
+            // perform sign comparison using bit manipulation to ensure NaN values are handled
+            // correctly
+            if ((rawBits & doubleSignBitMask) != 0) {
+              // for negative values, we need to flip all the bits to ensure correct ordering
+              val updatedVal = rawBits ^ doubleFlipBitMask
+              bbuf.put(negativeValMarker)
+              // convert the bits back to double
+              bbuf.putDouble(longBitsToDouble(updatedVal))
+            } else {
+              bbuf.put(positiveValMarker)
+              bbuf.putDouble(doubleVal)
+            }
+            writer.write(idx, bbuf.array())
+        }
+      }
+    }
+    encodeUnsafeRow(writer.getRow())
+  }
- override def encodeKey(row: UnsafeRow): Array[Byte] = {
- val (prefixKeyEncoded, remainingEncoded) = if (usingAvroEncoding) {
- (
- encodeUnsafeRowToAvro(
- extractPrefixKey(row),
- avroEnc.get.keySerializer,
- prefixKeyAvroType,
- out
- ),
- encodeUnsafeRowToAvro(
- remainingKeyProjection(row),
- avroEnc.get.suffixKeySerializer.get,
- remainingKeyAvroType,
- out
- )
- )
- } else {
-      (encodeUnsafeRow(extractPrefixKey(row)), encodeUnsafeRow(remainingKeyProjection(row)))
+ override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeToUnsafeRow(bytes, reusedKeyRow)
+ case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
+ decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
}
- val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
- prefixKeyEncoded.length + remainingEncoded.length + 4
- )
+ }
- Platform.putInt(encodedBytes, startingOffset, prefixKeyEncoded.length)
- Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
- encodedBytes, startingOffset + 4, prefixKeyEncoded.length)
-    // NOTE: We don't put the length of remainingEncoded as we can calculate later
-    // on deserialization.
- Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
- encodedBytes, startingOffset + 4 + prefixKeyEncoded.length,
- remainingEncoded.length)
-
- encodedBytes
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
+ decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+ case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) =>
+ decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length)
+      case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
}
- override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
- val prefixKeyEncodedLen = Platform.getInt(keyBytes, decodeKeyStartOffset)
- val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
- Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4,
- prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec])
+ val rsk = keyStateEncoderSpec.asInstanceOf[RangeKeyScanStateEncoderSpec]
+ val writer = new UnsafeRowWriter(rsk.orderingOrdinals.length)
+ val rangeScanKeyFieldsWithOrdinal = rsk.orderingOrdinals.map { ordinal =>
+ val field = rsk.keySchema(ordinal)
+ (field, ordinal)
+ }
+ writer.resetRowWriter()
+ val row = decodeToUnsafeRow(bytes, numFields = rsk.orderingOrdinals.length)
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
-    // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
- val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen -
- offsetForColFamilyPrefix
+ val value = row.getBinary(idx)
+ val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+ bbuf.order(ByteOrder.BIG_ENDIAN)
+ val isNullOrSignCol = bbuf.get()
+ if (isNullOrSignCol == nullValMarker) {
+ // set the column to null and skip reading the next byte(s)
+ writer.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ case ByteType =>
+ writer.write(idx, bbuf.get)
- val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
-    Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen,
-      remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen)
+ case ShortType =>
+ writer.write(idx, bbuf.getShort)
- val (prefixKeyDecoded, remainingKeyDecoded) = if (usingAvroEncoding) {
- (
- decodeFromAvroToUnsafeRow(
- prefixKeyEncoded,
- avroEnc.get.keyDeserializer,
- prefixKeyAvroType,
- prefixKeyProj
- ),
- decodeFromAvroToUnsafeRow(
- remainingKeyEncoded,
- avroEnc.get.suffixKeyDeserializer.get,
- remainingKeyAvroType,
- remainingKeyProj
- )
- )
- } else {
- (decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey),
- decodeToUnsafeRow(remainingKeyEncoded, numFields = keySchema.length -
numColsPrefixKey))
- }
+ case IntegerType =>
+ writer.write(idx, bbuf.getInt)
-    restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
- }
+ case LongType =>
+ writer.write(idx, bbuf.getLong)
- private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
- prefixKeyProjection(key)
- }
+          case FloatType =>
+            if (isNullOrSignCol == negativeValMarker) {
+              // if the number is negative, get the raw binary bits for the float
+              // and flip the bits back
+              val updatedVal = floatToRawIntBits(bbuf.getFloat) ^ floatFlipBitMask
+              writer.write(idx, intBitsToFloat(updatedVal))
+            } else {
+              writer.write(idx, bbuf.getFloat)
+            }
- override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
- val prefixKeyEncoded = if (usingAvroEncoding) {
-      encodeUnsafeRowToAvro(prefixKey, avroEnc.get.keySerializer, prefixKeyAvroType, out)
- } else {
- encodeUnsafeRow(prefixKey)
+          case DoubleType =>
+            if (isNullOrSignCol == negativeValMarker) {
+              // if the number is negative, get the raw binary bits for the double
+              // and flip the bits back
+              val updatedVal = doubleToRawLongBits(bbuf.getDouble) ^ doubleFlipBitMask
+              writer.write(idx, longBitsToDouble(updatedVal))
+            } else {
+              writer.write(idx, bbuf.getDouble)
+            }
+ }
+ }
}
- val (prefix, startingOffset) = encodeColumnFamilyPrefix(
- prefixKeyEncoded.length + 4
- )
-
- Platform.putInt(prefix, startingOffset, prefixKeyEncoded.length)
- Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefix,
- startingOffset + 4, prefixKeyEncoded.length)
- prefix
+ writer.getRow()
}
- override def supportPrefixKeyScan: Boolean = true
+  override def decodeValue(bytes: Array[Byte]): UnsafeRow = decodeToUnsafeRow(bytes, reusedValueRow)
}
-/**
- * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size fields
- *
- * To encode a row for range scan, we first project the orderingOrdinals from the original
- * UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's fields in BIG_ENDIAN
- * to allow for scanning keys in sorted order using the byte-wise comparison method that
- * RocksDB uses.
- *
- * Then, for the rest of the fields, we project those into another UnsafeRow.
- * We then effectively join these two UnsafeRows together, and finally take those bytes
- * to get the resulting row.
- *
- * We cannot support variable sized fields in the range scan because the UnsafeRow format
- * stores variable sized fields as offset and length pointers to the actual values,
- * thereby changing the required ordering.
- *
- * Note that we also support "null" values being passed for these fixed size fields. We prepend
- * a single byte to indicate whether the column value is null or not. We cannot change the
- * nullability on the UnsafeRow itself as the expected ordering would change if non-first
- * columns are marked as null. If the first col is null, those entries will appear last in
- * the iterator. If non-first columns are null, ordering based on the previous columns will
- * still be honored. For rows with null column values, ordering for subsequent columns
- * will also be maintained within those set of rows. We use the same byte to also encode whether
- * the value is negative or not. For negative float/double values, we flip all the bits to ensure
- * the right lexicographical ordering. For the rationale around this, please check the link
- * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
- *
- * @param keySchema - schema of the key to be encoded
- * @param orderingOrdinals - the ordinals for which the range scan is constructed
- * @param useColumnFamilies - if column family is enabled for this encoder
- * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
- *                  be defined
- */
-class RangeKeyScanStateEncoder(
-    keySchema: StructType,
-    orderingOrdinals: Seq[Int],
-    useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None,
-    avroEnc: Option[AvroEncoder] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging {
+class AvroStateEncoder(
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+    valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema)
+ with Logging {
- import RocksDBStateEncoder._
+ private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
+ // Avro schema used by the avro encoders
+ private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema)
+ private lazy val keyProj = UnsafeProjection.create(keySchema)
- private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
- orderingOrdinals.map { ordinal =>
- val field = keySchema(ordinal)
- (field, ordinal)
- }
- }
+  private lazy val valueAvroType: Schema = SchemaConverters.toAvroType(valueSchema)
+ private lazy val valueProj = UnsafeProjection.create(valueSchema)
- private def isFixedSize(dataType: DataType): Boolean = dataType match {
-    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
- _: FloatType | _: DoubleType => true
- case _ => false
+ // Prefix Key schema and projection definitions used by the Avro Serializers
+ // and Deserializers
+ private lazy val prefixKeySchema = keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+      StructType(keySchema.take(numColsPrefixKey))
+ case _ => throw unsupportedOperationForKeyStateEncoder("prefixKeySchema")
}
-
- // verify that only fixed sized columns are used for ordering
- rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
- if (!isFixedSize(field.dataType)) {
- // NullType is technically fixed size, but not supported for ordering
- if (field.dataType == NullType) {
-        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString)
- } else {
-        throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString)
+  private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
+ private lazy val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+
+  // Range Key schema and projection definitions used by the Avro Serializers and
+  // Deserializers
+ private lazy val rangeScanKeyFieldsWithOrdinal = keyStateEncoderSpec match {
+ case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+ orderingOrdinals.map { ordinal =>
+ val field = keySchema(ordinal)
+ (field, ordinal)
}
- }
+ case _ =>
+ throw unsupportedOperationForKeyStateEncoder("rangeScanKey")
}
- private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
- 0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
- val field = keySchema(ordinal)
- (field, ordinal)
- }
- }
+  private lazy val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+    StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))
- private val rangeScanKeyProjection: UnsafeProjection = {
- val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
- BoundReference(x._2, x._1.dataType, x._1.nullable))
- UnsafeProjection.create(refs)
- }
+  private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema)
- private val remainingKeyProjection: UnsafeProjection = {
- val refs = remainingKeyFieldsWithOrdinal.map(x =>
- BoundReference(x._2, x._1.dataType, x._1.nullable))
- UnsafeProjection.create(refs)
+  private lazy val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema)
+
+ // Existing remainder key schema definitions
+  // Remaining Key schema and projection definitions used by the Avro Serializers
+ // and Deserializers
+ private lazy val remainingKeySchema = keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+ StructType(keySchema.drop(numColsPrefixKey))
+ case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+      StructType(0.until(keySchema.length).diff(orderingOrdinals).map(keySchema(_)))
+    case _ => throw unsupportedOperationForKeyStateEncoder("remainingKeySchema")
}
- // The original schema that we might get could be:
- // [foo, bar, baz, buzz]
- // We might order by bar and buzz, leading to:
- // [bar, buzz, foo, baz]
-  // We need to create a projection that sends, for example, the buzz at index 1 to index
-  // 3. Thus, for every record in the original schema, we compute where it would be in
-  // the joined row and created a projection based on that.
- private val restoreKeyProjection: UnsafeProjection = {
- val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
-      val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal)) {
- orderingOrdinals.indexOf(originalOrdinal)
- } else {
- orderingOrdinals.length +
- remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
- }
+  private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
- BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
- }
- UnsafeProjection.create(refs)
+  private lazy val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)
+
+ private def getAvroSerializer(schema: StructType): AvroSerializer = {
+ val avroType = SchemaConverters.toAvroType(schema)
+ new AvroSerializer(schema, avroType, nullable = false)
}
-  private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
-    StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))
+ private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
+ val avroType = SchemaConverters.toAvroType(schema)
+ val avroOptions = AvroOptions(Map.empty)
+    new AvroDeserializer(avroType, schema,
+      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
+      avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
+ }
-  private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema)
+  /**
+   * Creates an AvroEncoder that handles both key and value serialization/deserialization.
+   * This method sets up the complete encoding infrastructure needed for state store operations.
+   *
+   * The encoder handles different key encoding specifications:
+   * - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix
+   * - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning
+   * - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans
+   *
+   * For prefix scan cases, it also creates separate encoders for the suffix portion of keys.
+   *
+   * @param keyStateEncoderSpec Specification for how to encode keys
+   * @param valueSchema Schema for the values to be encoded
+   * @return An AvroEncoder containing all necessary serializers and deserializers
+   */
+ private def createAvroEnc(
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ valueSchema: StructType): AvroEncoder = {
+ val valueSerializer = getAvroSerializer(valueSchema)
+ val valueDeserializer = getAvroDeserializer(valueSchema)
+
+ // Get key schema based on encoder spec type
+ val keySchema = keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(schema) =>
+ schema
+ case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+ StructType(schema.take(numColsPrefixKey))
+ case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
+ val remainingSchema = {
+ 0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
+ schema(ordinal)
+ }
+ }
+ StructType(remainingSchema)
+ }
-  private val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema)
+ // Handle suffix key schema for prefix scan case
+ val suffixKeySchema = keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+ Some(StructType(schema.drop(numColsPrefixKey)))
+ case _ =>
+ None
+ }
- // Existing remainder key schema stuff
- private val remainingKeySchema = StructType(
- 0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_))
- )
+ val keySerializer = getAvroSerializer(keySchema)
+ val keyDeserializer = getAvroDeserializer(keySchema)
+
+ // Create the AvroEncoder with all components
+ AvroEncoder(
+ keySerializer,
+ keyDeserializer,
+ valueSerializer,
+ valueDeserializer,
+ suffixKeySchema.map(getAvroSerializer),
+ suffixKeySchema.map(getAvroDeserializer)
+ )
+ }
-  private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
+ /**
+   * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding.
+ */
+ def encodeUnsafeRowToAvro(
+ row: UnsafeRow,
+ avroSerializer: AvroSerializer,
+ valueAvroType: Schema,
+ out: ByteArrayOutputStream): Array[Byte] = {
+ // InternalRow -> Avro.GenericDataRecord
+ val avroData =
+ avroSerializer.serialize(row)
+ out.reset()
+ val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
+ val writer = new GenericDatumWriter[Any](
+ valueAvroType) // Defining Avro writer for this struct type
+ writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array
+ encoder.flush()
+ out.toByteArray
+ }
-  private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)
+ /**
+ * This method takes a byte array written using Avro encoding, and
+ * deserializes to an UnsafeRow using the Avro deserializer
+ */
+ def decodeFromAvroToUnsafeRow(
+ valueBytes: Array[Byte],
+ avroDeserializer: AvroDeserializer,
+ valueAvroType: Schema,
+ valueProj: UnsafeProjection): UnsafeRow = {
+ if (valueBytes != null) {
+ val reader = new GenericDatumReader[Any](valueAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(
+ valueBytes, 0, valueBytes.length, null)
+ // bytes -> Avro.GenericDataRecord
+ val genericData = reader.read(null, decoder)
+ // Avro.GenericDataRecord -> InternalRow
+ val internalRow = avroDeserializer.deserialize(
+ genericData).orNull.asInstanceOf[InternalRow]
+ // InternalRow -> UnsafeRow
+ valueProj.apply(internalRow)
+ } else {
+ null
+ }
+ }
- // Reusable objects
- private val joinedRowOnKey = new JoinedRow()
+ private val out = new ByteArrayOutputStream
- private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
- rangeScanKeyProjection(key)
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out)
+ case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey")
+ }
}
-  // bit masks used for checking sign or flipping all bits for negative float/double values
-  private val floatFlipBitMask = 0xFFFFFFFF
-  private val floatSignBitMask = 0x80000000
-
-  private val doubleFlipBitMask = 0xFFFFFFFFFFFFFFFFL
-  private val doubleSignBitMask = 0x8000000000000000L
-
-  // Byte markers used to identify whether the value is null, negative or positive
-  // To ensure sorted ordering, we use the lowest byte value for negative numbers followed by
-  // positive numbers and then null values.
-  private val negativeValMarker: Byte = 0x00.toByte
-  private val positiveValMarker: Byte = 0x01.toByte
-  private val nullValMarker: Byte = 0x02.toByte
-
-  // Rewrite the unsafe row by replacing fixed size fields with BIG_ENDIAN encoding
-  // using byte arrays.
-  // To handle "null" values, we prepend a byte to the byte array indicating whether the value
-  // is null or not. If the value is null, we write the null byte followed by zero bytes.
-  // If the value is not null, we write the null byte followed by the value.
-  // Note that setting null for the index on the unsafeRow is not feasible as it would change
-  // the sorting order on iteration.
-  // Also note that the same byte is used to indicate whether the value is negative or not.
- private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
- val writer = new UnsafeRowWriter(orderingOrdinals.length)
- writer.resetRowWriter()
-    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
- val field = fieldWithOrdinal._1
- val value = row.get(idx, field.dataType)
-      // Note that we cannot allocate a smaller buffer here even if the value is null
-      // because the effective byte array is considered variable size and needs to have
-      // the same size across all rows for the ordering to work as expected.
- val bbuf = ByteBuffer.allocate(field.dataType.defaultSize + 1)
- bbuf.order(ByteOrder.BIG_ENDIAN)
- if (value == null) {
- bbuf.put(nullValMarker)
- writer.write(idx, bbuf.array())
- } else {
- field.dataType match {
- case BooleanType =>
- case ByteType =>
- val byteVal = value.asInstanceOf[Byte]
- val signCol = if (byteVal < 0) {
- negativeValMarker
- } else {
- positiveValMarker
- }
- bbuf.put(signCol)
- bbuf.put(byteVal)
- writer.write(idx, bbuf.array())
-
- case ShortType =>
- val shortVal = value.asInstanceOf[Short]
- val signCol = if (shortVal < 0) {
- negativeValMarker
- } else {
- positiveValMarker
- }
- bbuf.put(signCol)
- bbuf.putShort(shortVal)
- writer.write(idx, bbuf.array())
-
- case IntegerType =>
- val intVal = value.asInstanceOf[Int]
- val signCol = if (intVal < 0) {
- negativeValMarker
- } else {
- positiveValMarker
- }
- bbuf.put(signCol)
- bbuf.putInt(intVal)
- writer.write(idx, bbuf.array())
-
- case LongType =>
- val longVal = value.asInstanceOf[Long]
- val signCol = if (longVal < 0) {
- negativeValMarker
- } else {
- positiveValMarker
- }
- bbuf.put(signCol)
- bbuf.putLong(longVal)
- writer.write(idx, bbuf.array())
-
- case FloatType =>
- val floatVal = value.asInstanceOf[Float]
- val rawBits = floatToRawIntBits(floatVal)
-          // perform sign comparison using bit manipulation to ensure NaN values are handled
-          // correctly
-          if ((rawBits & floatSignBitMask) != 0) {
-            // for negative values, we need to flip all the bits to ensure correct ordering
- val updatedVal = rawBits ^ floatFlipBitMask
- bbuf.put(negativeValMarker)
- // convert the bits back to float
- bbuf.putFloat(intBitsToFloat(updatedVal))
- } else {
- bbuf.put(positiveValMarker)
- bbuf.putFloat(floatVal)
- }
- writer.write(idx, bbuf.array())
-
- case DoubleType =>
- val doubleVal = value.asInstanceOf[Double]
- val rawBits = doubleToRawLongBits(doubleVal)
-          // perform sign comparison using bit manipulation to ensure NaN values are handled
-          // correctly
-          if ((rawBits & doubleSignBitMask) != 0) {
-            // for negative values, we need to flip all the bits to ensure correct ordering
- val updatedVal = rawBits ^ doubleFlipBitMask
- bbuf.put(negativeValMarker)
- // convert the bits back to double
- bbuf.putDouble(longBitsToDouble(updatedVal))
- } else {
- bbuf.put(positiveValMarker)
- bbuf.putDouble(doubleVal)
- }
- writer.write(idx, bbuf.array())
- }
- }
- }
- writer.getRow()
- }
-
-  // Rewrite the unsafe row by converting back from BIG_ENDIAN byte arrays to the
-  // original data types.
-  // For decode, we extract the byte array from the UnsafeRow, and then read the first byte
-  // to determine if the value is null or not. If the value is null, we set the ordinal on
-  // the UnsafeRow to null. If the value is not null, we read the rest of the bytes to get the
-  // actual value.
-  // For negative float/double values, we need to flip all the bits back to get the original value.
- private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
- val writer = new UnsafeRowWriter(orderingOrdinals.length)
- writer.resetRowWriter()
-    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
- val field = fieldWithOrdinal._1
-
- val value = row.getBinary(idx)
- val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
- bbuf.order(ByteOrder.BIG_ENDIAN)
- val isNullOrSignCol = bbuf.get()
- if (isNullOrSignCol == nullValMarker) {
- // set the column to null and skip reading the next byte(s)
- writer.setNullAt(idx)
- } else {
- field.dataType match {
- case BooleanType =>
- case ByteType =>
- writer.write(idx, bbuf.get)
-
- case ShortType =>
- writer.write(idx, bbuf.getShort)
-
- case IntegerType =>
- writer.write(idx, bbuf.getInt)
-
- case LongType =>
- writer.write(idx, bbuf.getLong)
-
- case FloatType =>
- if (isNullOrSignCol == negativeValMarker) {
-            // if the number is negative, get the raw binary bits for the float
-            // and flip the bits back
-            val updatedVal = floatToRawIntBits(bbuf.getFloat) ^ floatFlipBitMask
- writer.write(idx, intBitsToFloat(updatedVal))
- } else {
- writer.write(idx, bbuf.getFloat)
- }
-
- case DoubleType =>
- if (isNullOrSignCol == negativeValMarker) {
-            // if the number is negative, get the raw binary bits for the double
-            // and flip the bits back
-            val updatedVal = doubleToRawLongBits(bbuf.getDouble) ^ doubleFlipBitMask
- writer.write(idx, longBitsToDouble(updatedVal))
- } else {
- writer.write(idx, bbuf.getDouble)
- }
- }
- }
- }
- writer.getRow()
- }
+ override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.suffixKeySerializer.get, remainingKeyAvroType, out)
+      case RangeKeyScanStateEncoderSpec(_, _) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, remainingKeyAvroType, out)
+      case _ => throw unsupportedOperationForKeyStateEncoder("encodeRemainingKey")
+ }
+ }
   /**
    * Encodes an UnsafeRow into an Avro-compatible byte array format for range scan operations.
@@ -704,10 +642,8 @@ class RangeKeyScanStateEncoder(
    * @throws UnsupportedOperationException if a field's data type is not supported for range
    *         scan encoding
    */
- def encodePrefixKeyForRangeScan(
- row: UnsafeRow,
- avroType: Schema): Array[Byte] = {
- val record = new GenericData.Record(avroType)
+ override def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] = {
+ val record = new GenericData.Record(rangeScanAvroType)
var fieldIdx = 0
     rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
val field = fieldWithOrdinal._1
@@ -810,146 +746,486 @@ class RangeKeyScanStateEncoder(
}
record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
-        case _ => throw new UnsupportedOperationException(
-          s"Range scan encoding not supported for data type: ${field.dataType}")
-      }
-    }
-    fieldIdx += 2
-  }
+        case _ => throw new UnsupportedOperationException(
+          s"Range scan encoding not supported for data type: ${field.dataType}")
+      }
+    }
+    fieldIdx += 2
+  }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+        decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer, keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+          avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType, remainingKeyAvroProjection)
+      case RangeKeyScanStateEncoderSpec(_, _) =>
+        decodeFromAvroToUnsafeRow(
+          bytes, avroEncoder.keyDeserializer, remainingKeyAvroType, remainingKeyAvroProjection)
+      case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+  /**
+   * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan operations.
+   *
+   * This method reverses the encoding process performed by encodePrefixKeyForRangeScan:
+   * - Reads the marker byte to determine null status or sign
+   * - Reconstructs the original values from big-endian format
+   * - Handles special cases for floating point numbers by reversing bit manipulations
+   *
+   * The decoding process preserves the original data types and values, including:
+   * - Null values marked by nullValMarker
+   * - Sign information for numeric types
+   * - Proper restoration of negative floating point values
+   *
+   * @param bytes The Avro-encoded byte array to decode
+   * @param avroType The Avro schema defining the structure for decoding
+   * @return UnsafeRow containing the decoded data
+   * @throws UnsupportedOperationException if a field's data type is not supported for range
+   *         scan decoding
+   */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+    val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+        field.dataType match {
+          case BooleanType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            rowWriter.write(idx, bytes(0) == 1)
+
+          case ByteType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.get())
+
+          case ShortType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.getShort())
+
+          case IntegerType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.getInt())
+
+          case LongType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            rowWriter.write(idx, valueBuf.getLong())
+
+          case FloatType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            if (marker == negativeValMarker) {
+              val floatVal = valueBuf.getFloat
+              val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+              rowWriter.write(idx, intBitsToFloat(updatedVal))
+            } else {
+              rowWriter.write(idx, valueBuf.getFloat())
+            }
+
+          case DoubleType =>
+            val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+            val valueBuf = ByteBuffer.wrap(bytes)
+            valueBuf.order(ByteOrder.BIG_ENDIAN)
+            if (marker == negativeValMarker) {
+              val doubleVal = valueBuf.getDouble
+              val updatedVal = doubleToRawLongBits(doubleVal) ^ doubleFlipBitMask
+              rowWriter.write(idx, longBitsToDouble(updatedVal))
+            } else {
+              rowWriter.write(idx, valueBuf.getDouble())
+            }
+
+          case _ => throw new UnsupportedOperationException(
+            s"Range scan decoding not supported for data type: ${field.dataType}")
+        }
+      }
+      fieldIdx += 2
+    }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+    Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+   * @param numBytes - size of byte array to be created for storing key row (without
+   *                    column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+      Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+/**
+ * Factory object for creating state encoders used by RocksDB state store.
+ *
+ * The encoders created by this object handle serialization and deserialization of state data,
+ * supporting both key and value encoding with various access patterns
+ * (e.g., prefix scan, range scan).
+ */
+object RocksDBStateEncoder extends Logging {
+
+  /**
+   * Creates a key encoder based on the specified encoding strategy and configuration.
+   *
+   * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding
+   * @param keyStateEncoderSpec Specification defining the key encoding strategy
+   *                            (no prefix, prefix scan, or range scan)
+   * @param useColumnFamilies Whether to use RocksDB column families for storage
+   * @param virtualColFamilyId Optional column family identifier when column families are enabled
+   * @return A configured RocksDBKeyStateEncoder instance
+   */
+  def getKeyEncoder(
+      dataEncoder: RocksDBDataEncoder,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      useColumnFamilies: Boolean,
+      virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder = {
+    keyStateEncoderSpec.toEncoder(dataEncoder, useColumnFamilies, virtualColFamilyId)
+  }
+
+  /**
+   * Creates a value encoder that supports either single or multiple values per key.
+   *
+   * @param dataEncoder The underlying encoder that handles the actual data encoding/decoding
+   * @param valueSchema Schema defining the structure of values to be encoded
+   * @param useMultipleValuesPerKey If true, creates an encoder that can handle multiple values
+   *                                per key; if false, creates an encoder for single values
+   * @return A configured RocksDBValueStateEncoder instance
+   */
+  def getValueEncoder(
+      dataEncoder: RocksDBDataEncoder,
+      valueSchema: StructType,
+      useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = {
+    if (useMultipleValuesPerKey) {
+      new MultiValuedStateEncoder(dataEncoder, valueSchema)
+    } else {
+      new SingleValueStateEncoder(dataEncoder, valueSchema)
+    }
+  }
+
+  /**
+   * Encodes a virtual column family ID into a byte array suitable for RocksDB.
+   *
+   * This method creates a fixed-size byte array prefixed with the virtual column family ID,
+   * which is used to partition data within RocksDB.
+   *
+   * @param virtualColFamilyId The column family identifier to encode
+   * @return A byte array containing the encoded column family ID
+   */
+  def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = {
+    val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+    Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId)
+    encodedBytes
+  }
+}
+
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports prefix scan
+ *
+ * @param dataEncoder - the encoder that handles actual encoding/decoding of data
+ * @param keySchema - schema of the key to be encoded
+ * @param numColsPrefixKey - number of columns to be used for prefix key
+ * @param useColumnFamilies - if column family is enabled for this encoder
+ */
+class PrefixKeyScanStateEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keySchema: StructType,
+ numColsPrefixKey: Int,
+ useColumnFamilies: Boolean = false,
+ virtualColFamilyId: Option[Short] = None)
+  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging {
+
+ private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.take(numColsPrefixKey)
+ }
+
+ private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.drop(numColsPrefixKey)
+ }
+
+ private val prefixKeyProjection: UnsafeProjection = {
+ val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2,
x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithIdx.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ // This is straightforward: bind fields sequentially, since we don't change
the order.
+ private val restoreKeyProjection: UnsafeProjection =
UnsafeProjection.create(keySchema)
+
+ // Reusable objects
+ private val joinedRowOnKey = new JoinedRow()
+
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = dataEncoder.encodeKey(extractPrefixKey(row))
+ val remainingEncoded =
dataEncoder.encodeRemainingKey(remainingKeyProjection(row))
+
+ val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
+ prefixKeyEncoded.length + remainingEncoded.length + 4
+ )
+
+ Platform.putInt(encodedBytes, startingOffset, prefixKeyEncoded.length)
+ Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, startingOffset + 4, prefixKeyEncoded.length)
+ // NOTE: We don't store the length of remainingEncoded, as it can be
+ // recalculated on deserialization.
+ Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, startingOffset + 4 + prefixKeyEncoded.length,
+ remainingEncoded.length)
+
+ encodedBytes
+ }
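
The length arithmetic used by encodeKey and decodeKey can be checked with a
small sketch (sizes hypothetical; the column family prefix applies only when
column families are enabled):

  // Layout: [ cf prefix? | 4-byte prefixLen | prefix bytes | remaining bytes ]
  val cfPrefixLen = 2   // offsetForColFamilyPrefix with column families on
  val prefixLen = 10    // length of prefixKeyEncoded, stored in the key
  val remainingLen = 6  // length of remainingEncoded, NOT stored
  val totalLen = cfPrefixLen + 4 + prefixLen + remainingLen
  // decodeKey recovers the unstored remaining length from the total:
  assert(totalLen - 4 - prefixLen - cfPrefixLen == remainingLen)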
+
+ override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+ val prefixKeyEncodedLen = Platform.getInt(keyBytes, decodeKeyStartOffset)
+ val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
+ Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4,
+ prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
+
+ // Here we calculate the remainingKeyEncodedLen leveraging the length of
keyBytes
+ val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen -
+ offsetForColFamilyPrefix
- out.reset()
- val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
- val encoder = EncoderFactory.get().binaryEncoder(out, null)
- writer.write(record, encoder)
- encoder.flush()
- out.toByteArray
- }
+ val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+ Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 +
prefixKeyEncodedLen,
+ remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen)
- /**
- * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan
operations.
- *
- * This method reverses the encoding process performed by
encodePrefixKeyForRangeScan:
- * - Reads the marker byte to determine null status or sign
- * - Reconstructs the original values from big-endian format
- * - Handles special cases for floating point numbers by reversing bit
manipulations
- *
- * The decoding process preserves the original data types and values,
including:
- * - Null values marked by nullValMarker
- * - Sign information for numeric types
- * - Proper restoration of negative floating point values
- *
- * @param bytes The Avro-encoded byte array to decode
- * @param avroType The Avro schema defining the structure for decoding
- * @return UnsafeRow containing the decoded data
- * @throws UnsupportedOperationException if a field's data type is not
supported for range
- * scan decoding
- */
- def decodePrefixKeyForRangeScan(
- bytes: Array[Byte],
- avroType: Schema): UnsafeRow = {
+ val prefixKeyDecoded = dataEncoder.decodeKey(
+ prefixKeyEncoded)
+ val remainingKeyDecoded =
dataEncoder.decodeRemainingKey(remainingKeyEncoded)
- val reader = new GenericDatumReader[GenericRecord](avroType)
- val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
- val record = reader.read(null, decoder)
+
restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
+ }
- val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
- rowWriter.resetRowWriter()
+ private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+ prefixKeyProjection(key)
+ }
- var fieldIdx = 0
- rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case
(fieldWithOrdinal, idx) =>
- val field = fieldWithOrdinal._1
+ override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = dataEncoder.encodeKey(prefixKey)
+ val (prefix, startingOffset) = encodeColumnFamilyPrefix(
+ prefixKeyEncoded.length + 4
+ )
- val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
- val markerBuf = ByteBuffer.wrap(markerBytes)
- markerBuf.order(ByteOrder.BIG_ENDIAN)
- val marker = markerBuf.get()
+ Platform.putInt(prefix, startingOffset, prefixKeyEncoded.length)
+ Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefix,
+ startingOffset + 4, prefixKeyEncoded.length)
+ prefix
+ }
- if (marker == nullValMarker) {
- rowWriter.setNullAt(idx)
- } else {
- field.dataType match {
- case BooleanType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- rowWriter.write(idx, bytes(0) == 1)
+ override def supportPrefixKeyScan: Boolean = true
+}
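
The range-scan encoder documented below flips every bit of a negative float
or double so that RocksDB's byte-wise comparison agrees with numeric order.
A minimal sketch of that trick (the real encoder also prepends a null/sign
marker byte, and the exact flip mask is an implementation detail):

  import java.lang.Float.floatToRawIntBits

  def sortableBits(f: Float): Int = {
    val bits = floatToRawIntBits(f)
    if (f < 0) bits ^ 0xFFFFFFFF else bits // flip all bits for negatives
  }
  // -1.0f < -0.5f, and the flipped bit patterns preserve that order:
  assert(sortableBits(-1.0f) < sortableBits(-0.5f))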
- case ByteType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- rowWriter.write(idx, valueBuf.get())
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed-size
fields
+ *
+ * To encode a row for range scan, we first project the orderingOrdinals from
the original
+ * UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's
fields in BIG_ENDIAN
+ * to allow for scanning keys in sorted order using the byte-wise comparison
method that
+ * RocksDB uses.
+ *
+ * Then, for the rest of the fields, we project those into another UnsafeRow.
+ * We then effectively join these two UnsafeRows together, and finally take
those bytes
+ * to get the resulting row.
+ *
+ * We cannot support variable sized fields in the range scan because the
UnsafeRow format
+ * stores variable sized fields as offset and length pointers to the actual
values,
+ * thereby changing the required ordering.
+ *
+ * Note that we also support "null" values being passed for these fixed size
fields. We prepend
+ * a single byte to indicate whether the column value is null or not. We
cannot change the
+ * nullability on the UnsafeRow itself as the expected ordering would change
if non-first
+ * columns are marked as null. If the first col is null, those entries will
appear last in
+ * the iterator. If non-first columns are null, ordering based on the previous
columns will
+ * still be honored. For rows with null column values, ordering for subsequent
columns
+ * will also be maintained within that set of rows. We use the same byte to
also encode whether
+ * the value is negative or not. For negative float/double values, we flip all
the bits to ensure
+ * the right lexicographical ordering. For the rationale around this, please
check the link
+ * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
+ *
+ * @param dataEncoder - the encoder that handles the actual encoding/decoding
of data
+ * @param keySchema - schema of the key to be encoded
+ * @param orderingOrdinals - the ordinals for which the range scan is
constructed
+ * @param useColumnFamilies - if column family is enabled for this encoder
+ */
+class RangeKeyScanStateEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keySchema: StructType,
+ orderingOrdinals: Seq[Int],
+ useColumnFamilies: Boolean = false,
+ virtualColFamilyId: Option[Short] = None)
+ extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId)
with Logging {
- case ShortType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- rowWriter.write(idx, valueBuf.getShort())
+ private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+ orderingOrdinals.map { ordinal =>
+ val field = keySchema(ordinal)
+ (field, ordinal)
+ }
+ }
- case IntegerType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- rowWriter.write(idx, valueBuf.getInt())
+ private def isFixedSize(dataType: DataType): Boolean = dataType match {
+ case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _:
LongType |
+ _: FloatType | _: DoubleType => true
+ case _ => false
+ }
- case LongType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- rowWriter.write(idx, valueBuf.getLong())
+ // verify that only fixed-size columns are used for ordering
+ rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
+ if (!isFixedSize(field.dataType)) {
+ // NullType is technically fixed size, but not supported for ordering
+ if (field.dataType == NullType) {
+ throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name,
ordinal.toString)
+ } else {
+ throw
StateStoreErrors.variableSizeOrderingColsNotSupported(field.name,
ordinal.toString)
+ }
+ }
+ }
- case FloatType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- if (marker == negativeValMarker) {
- val floatVal = valueBuf.getFloat
- val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
- rowWriter.write(idx, intBitsToFloat(updatedVal))
- } else {
- rowWriter.write(idx, valueBuf.getFloat())
- }
+ private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+ 0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
+ val field = keySchema(ordinal)
+ (field, ordinal)
+ }
+ }
- case DoubleType =>
- val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- if (marker == negativeValMarker) {
- val doubleVal = valueBuf.getDouble
- val updatedVal = doubleToRawLongBits(doubleVal) ^
doubleFlipBitMask
- rowWriter.write(idx, longBitsToDouble(updatedVal))
- } else {
- rowWriter.write(idx, valueBuf.getDouble())
- }
+ private val rangeScanKeyProjection: UnsafeProjection = {
+ val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
- case _ => throw new UnsupportedOperationException(
- s"Range scan decoding not supported for data type:
${field.dataType}")
- }
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithOrdinal.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ // The original schema that we might get could be:
+ // [foo, bar, baz, buzz]
+ // We might order by bar and buzz, leading to:
+ // [bar, buzz, foo, baz]
+ // We need to create a projection that sends, for example, the buzz at index
1 to index
+ * 3. Thus, for every field in the original schema, we compute where it
would be in
+ * the joined row and create a projection based on that.
+ private val restoreKeyProjection: UnsafeProjection = {
+ val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
+ val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal))
{
+ orderingOrdinals.indexOf(originalOrdinal)
+ } else {
+ orderingOrdinals.length +
+ remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
}
- fieldIdx += 2
+
+ BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
}
+ UnsafeProjection.create(refs)
+ }
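
Working the comment's example through the index arithmetic above (field
names hypothetical):

  val orderingOrdinals = Seq(1, 3)  // bar, buzz out of [foo, bar, baz, buzz]
  val remainingOrdinals = Seq(0, 2) // foo, baz
  def ordinalInJoinedRow(ord: Int): Int =
    if (orderingOrdinals.contains(ord)) orderingOrdinals.indexOf(ord)
    else orderingOrdinals.length + remainingOrdinals.indexOf(ord)
  // The joined row is [bar, buzz, foo, baz], so the original fields land at:
  assert((0 to 3).map(ordinalInJoinedRow) == Seq(2, 0, 3, 1))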
- rowWriter.getRow()
+ // Reusable objects
+ private val joinedRowOnKey = new JoinedRow()
+
+ private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+ rangeScanKeyProjection(key)
}
override def encodeKey(row: UnsafeRow): Array[Byte] = {
// This prefix key has the columns specified by orderingOrdinals
val prefixKey = extractPrefixKey(row)
- val rangeScanKeyEncoded = if (avroEnc.isDefined) {
- encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType)
- } else {
- encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
- }
+ val rangeScanKeyEncoded =
dataEncoder.encodePrefixKeyForRangeScan(prefixKey)
val result = if (orderingOrdinals.length < keySchema.length) {
- val remainingEncoded = if (avroEnc.isDefined) {
- encodeUnsafeRowToAvro(
- remainingKeyProjection(row),
- avroEnc.get.keySerializer,
- remainingKeyAvroType,
- out
- )
- } else {
- encodeUnsafeRow(remainingKeyProjection(row))
- }
+ val remainingEncoded =
dataEncoder.encodeRemainingKey(remainingKeyProjection(row))
val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
rangeScanKeyEncoded.length + remainingEncoded.length + 4
)
@@ -986,12 +1262,8 @@ class RangeKeyScanStateEncoder(
Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4,
prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
- val prefixKeyDecoded = if (avroEnc.isDefined) {
- decodePrefixKeyForRangeScan(prefixKeyEncoded, rangeScanAvroType)
- } else {
- decodePrefixKeyForRangeScan(decodeToUnsafeRow(prefixKeyEncoded,
- numFields = orderingOrdinals.length))
- }
+ val prefixKeyDecoded = dataEncoder.decodePrefixKeyForRangeScan(
+ prefixKeyEncoded)
if (orderingOrdinals.length < keySchema.length) {
// Here we calculate the remainingKeyEncodedLen leveraging the length of
keyBytes
@@ -1003,14 +1275,7 @@ class RangeKeyScanStateEncoder(
remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
remainingKeyEncodedLen)
- val remainingKeyDecoded = if (avroEnc.isDefined) {
- decodeFromAvroToUnsafeRow(remainingKeyEncoded,
- avroEnc.get.keyDeserializer,
- remainingKeyAvroType, remainingKeyAvroProjection)
- } else {
- decodeToUnsafeRow(remainingKeyEncoded,
- numFields = keySchema.length - orderingOrdinals.length)
- }
+ val remainingKeyDecoded =
dataEncoder.decodeRemainingKey(remainingKeyEncoded)
val joined =
joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded)
val restored = restoreKeyProjection(joined)
@@ -1023,11 +1288,7 @@ class RangeKeyScanStateEncoder(
}
override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
- val rangeScanKeyEncoded = if (avroEnc.isDefined) {
- encodePrefixKeyForRangeScan(prefixKey, rangeScanAvroType)
- } else {
- encodeUnsafeRow(encodePrefixKeyForRangeScan(prefixKey))
- }
+ val rangeScanKeyEncoded =
dataEncoder.encodePrefixKeyForRangeScan(prefixKey)
val (prefix, startingOffset) =
encodeColumnFamilyPrefix(rangeScanKeyEncoded.length + 4)
Platform.putInt(prefix, startingOffset, rangeScanKeyEncoded.length)
@@ -1046,36 +1307,23 @@ class RangeKeyScanStateEncoder(
* It uses the first byte of the generated byte array to store the version that
describes how the
* row is encoded in the rest of the byte array. Currently, the default
version is 0.
*
- * If the avroEnc is specified, we are using Avro encoding for this column
family's keys
* VERSION 0: [ VERSION (1 byte) | ROW (N bytes) ]
* The bytes of an UnsafeRow are written unmodified starting from offset 1
* (offset 0 is the version byte of value 0). That is, if the unsafe row
has N bytes,
* then the generated byte array will be N+1 bytes.
*/
class NoPrefixKeyStateEncoder(
+ dataEncoder: RocksDBDataEncoder,
keySchema: StructType,
useColumnFamilies: Boolean = false,
- virtualColFamilyId: Option[Short] = None,
- avroEnc: Option[AvroEncoder] = None)
+ virtualColFamilyId: Option[Short] = None)
extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId)
with Logging {
- import RocksDBStateEncoder._
-
- // Reusable objects
- private val usingAvroEncoding = avroEnc.isDefined
- private val keyRow = new UnsafeRow(keySchema.size)
- private lazy val keyAvroType = SchemaConverters.toAvroType(keySchema)
- private val keyProj = UnsafeProjection.create(keySchema)
-
override def encodeKey(row: UnsafeRow): Array[Byte] = {
if (!useColumnFamilies) {
- encodeUnsafeRow(row)
+ dataEncoder.encodeKey(row)
} else {
- // If avroEnc is defined, we know that we need to use Avro to
- // encode this UnsafeRow to Avro bytes
- val bytesToEncode = if (usingAvroEncoding) {
- encodeUnsafeRowToAvro(row, avroEnc.get.keySerializer, keyAvroType, out)
- } else row.getBytes
+ val bytesToEncode = dataEncoder.encodeKey(row)
val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
bytesToEncode.length +
STATE_ENCODING_NUM_VERSION_BYTES
@@ -1098,26 +1346,23 @@ class NoPrefixKeyStateEncoder(
override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
if (useColumnFamilies) {
if (keyBytes != null) {
- // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st
offset. See Platform.
- if (usingAvroEncoding) {
- val dataLength = keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES -
- VIRTUAL_COL_FAMILY_PREFIX_BYTES
- val avroBytes = new Array[Byte](dataLength)
- Platform.copyMemory(
- keyBytes, decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES,
- avroBytes, Platform.BYTE_ARRAY_OFFSET, dataLength)
- decodeFromAvroToUnsafeRow(avroBytes, avroEnc.get.keyDeserializer,
keyAvroType, keyProj)
- } else {
- keyRow.pointTo(
- keyBytes,
- decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES,
- keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES -
VIRTUAL_COL_FAMILY_PREFIX_BYTES)
- keyRow
- }
+ // Create new byte array without prefix
+ val dataLength = keyBytes.length -
+ STATE_ENCODING_NUM_VERSION_BYTES - VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ val dataBytes = new Array[Byte](dataLength)
+ Platform.copyMemory(
+ keyBytes,
+ decodeKeyStartOffset + STATE_ENCODING_NUM_VERSION_BYTES,
+ dataBytes,
+ Platform.BYTE_ARRAY_OFFSET,
+ dataLength)
+ dataEncoder.decodeKey(dataBytes)
} else {
null
}
- } else decodeToUnsafeRow(keyBytes, keyRow)
+ } else {
+ dataEncoder.decodeKey(keyBytes)
+ }
}
override def supportPrefixKeyScan: Boolean = false
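
The VERSION 0 framing described in the class comment amounts to the
following sketch (STATE_ENCODING_NUM_VERSION_BYTES is 1; the payload bytes
are hypothetical):

  val rowBytes = Array[Byte](1, 2, 3)          // hypothetical row payload
  val framed = new Array[Byte](rowBytes.length + 1)
  framed(0) = 0                                // version byte, value 0
  System.arraycopy(rowBytes, 0, framed, 1, rowBytes.length)
  assert(framed.length == rowBytes.length + 1) // N + 1 bytes in total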
@@ -1139,28 +1384,14 @@ class NoPrefixKeyStateEncoder(
* This encoder supports RocksDB StringAppendOperator merge operator. Values
encoded can be
* merged in RocksDB using merge operation, and all merged values can be read
using decodeValues
* operation.
- * If the avroEnc is specified, we are using Avro encoding for this column
family's values
*/
class MultiValuedStateEncoder(
- valueSchema: StructType,
- avroEnc: Option[AvroEncoder] = None)
+ dataEncoder: RocksDBDataEncoder,
+ valueSchema: StructType)
extends RocksDBValueStateEncoder with Logging {
- import RocksDBStateEncoder._
-
- private val usingAvroEncoding = avroEnc.isDefined
- // Reusable objects
- private val out = new ByteArrayOutputStream
- private val valueRow = new UnsafeRow(valueSchema.size)
- private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema)
- private val valueProj = UnsafeProjection.create(valueSchema)
-
override def encodeValue(row: UnsafeRow): Array[Byte] = {
- val bytes = if (usingAvroEncoding) {
- encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType,
out)
- } else {
- encodeUnsafeRow(row)
- }
+ val bytes = dataEncoder.encodeValue(row)
val numBytes = bytes.length
val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length)
@@ -1179,12 +1410,7 @@ class MultiValuedStateEncoder(
val encodedValue = new Array[Byte](numBytes)
Platform.copyMemory(valueBytes, java.lang.Integer.BYTES +
Platform.BYTE_ARRAY_OFFSET,
encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes)
- if (usingAvroEncoding) {
- decodeFromAvroToUnsafeRow(
- encodedValue, avroEnc.get.valueDeserializer, valueAvroType,
valueProj)
- } else {
- decodeToUnsafeRow(encodedValue, valueRow)
- }
+ dataEncoder.decodeValue(encodedValue)
}
}
@@ -1210,12 +1436,7 @@ class MultiValuedStateEncoder(
pos += numBytes
pos += 1 // eat the delimiter character
- if (usingAvroEncoding) {
- decodeFromAvroToUnsafeRow(
- encodedValue, avroEnc.get.valueDeserializer, valueAvroType,
valueProj)
- } else {
- decodeToUnsafeRow(encodedValue, valueRow)
- }
+ dataEncoder.decodeValue(encodedValue)
}
}
}
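
decodeValues above walks a buffer in which RocksDB's merge operator has
concatenated individually encoded values. A hedged sketch of that walk,
using ByteBuffer in native order as a stand-in for Platform.getInt (layout
as in encodeValue, plus the one-byte delimiter):

  import java.nio.{ByteBuffer, ByteOrder}

  def splitMerged(merged: Array[Byte]): Seq[Array[Byte]] = {
    val out = Seq.newBuilder[Array[Byte]]
    var pos = 0
    while (pos < merged.length) {
      val len =
        ByteBuffer.wrap(merged, pos, 4).order(ByteOrder.nativeOrder()).getInt
      pos += 4
      out += java.util.Arrays.copyOfRange(merged, pos, pos + len)
      pos += len + 1 // skip the value and the single delimiter byte
    }
    out.result()
  }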
@@ -1235,29 +1456,13 @@ class MultiValuedStateEncoder(
* The bytes of an UnsafeRow are written unmodified starting from offset 1
* (offset 0 is the version byte of value 0). That is, if the unsafe row
has N bytes,
* then the generated byte array will be N+1 bytes.
- * If the avroEnc is specified, we are using Avro encoding for this column
family's values
*/
class SingleValueStateEncoder(
- valueSchema: StructType,
- avroEnc: Option[AvroEncoder] = None)
- extends RocksDBValueStateEncoder with Logging {
-
- import RocksDBStateEncoder._
-
- private val usingAvroEncoding = avroEnc.isDefined
- // Reusable objects
- private val out = new ByteArrayOutputStream
- private val valueRow = new UnsafeRow(valueSchema.size)
- private lazy val valueAvroType = SchemaConverters.toAvroType(valueSchema)
- private val valueProj = UnsafeProjection.create(valueSchema)
+ dataEncoder: RocksDBDataEncoder,
+ valueSchema: StructType)
+ extends RocksDBValueStateEncoder {
- override def encodeValue(row: UnsafeRow): Array[Byte] = {
- if (usingAvroEncoding) {
- encodeUnsafeRowToAvro(row, avroEnc.get.valueSerializer, valueAvroType,
out)
- } else {
- encodeUnsafeRow(row)
- }
- }
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
dataEncoder.encodeValue(row)
/**
* Decode byte array for a value to a UnsafeRow.
@@ -1266,15 +1471,7 @@ class SingleValueStateEncoder(
* the given byte array.
*/
override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
- if (valueBytes == null) {
- return null
- }
- if (usingAvroEncoding) {
- decodeFromAvroToUnsafeRow(
- valueBytes, avroEnc.get.valueDeserializer, valueAvroType, valueProj)
- } else {
- decodeToUnsafeRow(valueBytes, valueRow)
- }
+ dataEncoder.decodeValue(valueBytes)
}
override def supportsMultipleValuesPerKey: Boolean = false
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index c9c987fa1620..fb0bf84d7aab 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -30,7 +30,6 @@ import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions,
AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager,
StreamExecution}
@@ -76,17 +75,28 @@ private[sql] class RocksDBStateStoreProvider
isInternal: Boolean = false): Unit = {
verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName,
isInternal)
val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
- // Create cache key using store ID to avoid collisions
- val avroEncCacheKey =
s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
- s"${stateStoreId.partitionId}_$colFamilyName"
-
- val avroEnc = getAvroEnc(
- stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
-
- keyValueEncoderMap.putIfAbsent(colFamilyName,
- (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
useColumnFamilies,
- Some(newColFamilyId), avroEnc),
RocksDBStateEncoder.getValueEncoder(valueSchema,
- useMultipleValuesPerKey, avroEnc)))
+ val dataEncoderCacheKey = StateRowEncoderCacheKey(
+ queryRunId = getRunId(hadoopConf),
+ operatorId = stateStoreId.operatorId,
+ partitionId = stateStoreId.partitionId,
+ stateStoreName = stateStoreId.storeName,
+ colFamilyName = colFamilyName)
+
+ val dataEncoder = getDataEncoder(
+ stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec,
valueSchema)
+
+ val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
+ dataEncoder,
+ keyStateEncoderSpec,
+ useColumnFamilies,
+ Some(newColFamilyId)
+ )
+ val valueEncoder = RocksDBStateEncoder.getValueEncoder(
+ dataEncoder,
+ valueSchema,
+ useMultipleValuesPerKey
+ )
+ keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder))
}
override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
@@ -387,17 +397,28 @@ private[sql] class RocksDBStateStoreProvider
defaultColFamilyId =
Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
}
- val colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME
- // Create cache key using store ID to avoid collisions
- val avroEncCacheKey =
s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
- s"${stateStoreId.partitionId}_$colFamilyName"
- val avroEnc = getAvroEnc(
- stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
+ val dataEncoderCacheKey = StateRowEncoderCacheKey(
+ queryRunId = getRunId(hadoopConf),
+ operatorId = stateStoreId.operatorId,
+ partitionId = stateStoreId.partitionId,
+ stateStoreName = stateStoreId.storeName,
+ colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME)
- keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
- (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
- useColumnFamilies, defaultColFamilyId, avroEnc),
- RocksDBStateEncoder.getValueEncoder(valueSchema,
useMultipleValuesPerKey, avroEnc)))
+ val dataEncoder = getDataEncoder(
+ stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec,
valueSchema)
+
+ val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
+ dataEncoder,
+ keyStateEncoderSpec,
+ useColumnFamilies,
+ defaultColFamilyId
+ )
+ val valueEncoder = RocksDBStateEncoder.getValueEncoder(
+ dataEncoder,
+ valueSchema,
+ useMultipleValuesPerKey
+ )
+ keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
(keyEncoder, valueEncoder))
}
override def stateStoreId: StateStoreId = stateStoreId_
@@ -605,6 +626,15 @@ private[sql] class RocksDBStateStoreProvider
}
}
+
+case class StateRowEncoderCacheKey(
+ queryRunId: String,
+ operatorId: Long,
+ partitionId: Int,
+ stateStoreName: String,
+ colFamilyName: String
+)
+
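
A hypothetical cache key; carrying the store name and column family means
two column families in the same partition never share an encoder entry:

  val cacheKey = StateRowEncoderCacheKey(
    queryRunId = "run-1",          // hypothetical run id
    operatorId = 0L,
    partitionId = 3,
    stateStoreName = "default",
    colFamilyName = "countState")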
object RocksDBStateStoreProvider {
// Version as a single byte that specifies the encoding of the row data in
RocksDB
val STATE_ENCODING_NUM_VERSION_BYTES = 1
@@ -615,30 +645,48 @@ object RocksDBStateStoreProvider {
private val AVRO_ENCODER_LIFETIME_HOURS = 1L
// Add the cache at companion object level so it persists across provider
instances
- private val avroEncoderMap: NonFateSharingCache[String, AvroEncoder] =
+ private val dataEncoderCache: NonFateSharingCache[StateRowEncoderCacheKey,
RocksDBDataEncoder] =
NonFateSharingCache(
maximumSize = MAX_AVRO_ENCODERS_IN_CACHE,
expireAfterAccessTime = AVRO_ENCODER_LIFETIME_HOURS,
expireAfterAccessTimeUnit = TimeUnit.HOURS
)
- def getAvroEnc(
+ /**
+ * Creates and returns a data encoder for the state store based on the
specified encoding type.
+ * This method handles caching of encoders to improve performance by reusing
encoder instances
+ * when possible.
+ *
+ * The method supports two encoding types:
+ * - Avro: Uses Apache Avro for serialization with schema evolution support
+ * - UnsafeRow: Uses Spark's internal row format for optimal performance
+ *
+ * @param stateStoreEncoding The encoding type to use ("avro" or "unsaferow")
+ * @param encoderCacheKey A unique key for caching the encoder instance,
combining the query
+ * run ID, operator ID, partition ID, state store name,
and column family name
+ * @param keyStateEncoderSpec Specification for how to encode keys,
including schema and any
+ * prefix/range scan requirements
+ * @param valueSchema The schema for the values to be encoded
+ * @return A RocksDBDataEncoder instance configured for the specified
encoding type
+ */
+ def getDataEncoder(
stateStoreEncoding: String,
- avroEncCacheKey: String,
+ encoderCacheKey: StateRowEncoderCacheKey,
keyStateEncoderSpec: KeyStateEncoderSpec,
- valueSchema: StructType): Option[AvroEncoder] = {
-
- stateStoreEncoding match {
- case "avro" => Some(
- RocksDBStateStoreProvider.avroEncoderMap.get(
- avroEncCacheKey,
- new java.util.concurrent.Callable[AvroEncoder] {
- override def call(): AvroEncoder =
createAvroEnc(keyStateEncoderSpec, valueSchema)
+ valueSchema: StructType): RocksDBDataEncoder = {
+ assert(Set("avro", "unsaferow").contains(stateStoreEncoding))
+ RocksDBStateStoreProvider.dataEncoderCache.get(
+ encoderCacheKey,
+ new java.util.concurrent.Callable[RocksDBDataEncoder] {
+ override def call(): RocksDBDataEncoder = {
+ if (stateStoreEncoding == "avro") {
+ new AvroStateEncoder(keyStateEncoderSpec, valueSchema)
+ } else {
+ new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema)
}
- )
- )
- case "unsaferow" => None
- }
+ }
+ }
+ )
}
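
A usage sketch of getDataEncoder (cacheKey, spec, and valueSchema as in the
earlier sketches; repeated calls with an equal key return the cached
instance until it expires):

  val dataEncoder = RocksDBStateStoreProvider.getDataEncoder(
    stateStoreEncoding = "avro",   // or "unsaferow"
    encoderCacheKey = cacheKey,
    keyStateEncoderSpec = spec,
    valueSchema = valueSchema)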
private def getRunId(hadoopConf: Configuration): String = {
@@ -651,53 +699,6 @@ object RocksDBStateStoreProvider {
}
}
- private def getAvroSerializer(schema: StructType): AvroSerializer = {
- val avroType = SchemaConverters.toAvroType(schema)
- new AvroSerializer(schema, avroType, nullable = false)
- }
-
- private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
- val avroType = SchemaConverters.toAvroType(schema)
- val avroOptions = AvroOptions(Map.empty)
- new AvroDeserializer(avroType, schema,
- avroOptions.datetimeRebaseModeInRead,
avroOptions.useStableIdForUnionType,
- avroOptions.stableIdPrefixForUnionType,
avroOptions.recursiveFieldMaxDepth)
- }
-
- private def createAvroEnc(
- keyStateEncoderSpec: KeyStateEncoderSpec,
- valueSchema: StructType
- ): AvroEncoder = {
- val valueSerializer = getAvroSerializer(valueSchema)
- val valueDeserializer = getAvroDeserializer(valueSchema)
- val keySchema = keyStateEncoderSpec match {
- case NoPrefixKeyStateEncoderSpec(schema) =>
- schema
- case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
- StructType(schema.take(numColsPrefixKey))
- case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
- val remainingSchema = {
- 0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
- schema(ordinal)
- }
- }
- StructType(remainingSchema)
- }
- val suffixKeySchema = keyStateEncoderSpec match {
- case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
- Some(StructType(schema.drop(numColsPrefixKey)))
- case _ => None
- }
- AvroEncoder(
- getAvroSerializer(keySchema),
- getAvroDeserializer(keySchema),
- valueSerializer,
- valueDeserializer,
- suffixKeySchema.map(getAvroSerializer),
- suffixKeySchema.map(getAvroDeserializer)
- )
- }
-
// Native operation latencies are reported in microseconds,
// but SQLMetrics supports millis. Convert the value to millis
val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index e2b93c147891..de10518035e2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -322,8 +322,22 @@ case class StateStoreCustomTimingMetric(name: String,
desc: String) extends Stat
}
sealed trait KeyStateEncoderSpec {
+ def keySchema: StructType
def jsonValue: JValue
def json: String = compact(render(jsonValue))
+
+ /**
+ * Creates a RocksDBKeyStateEncoder for this specification.
+ *
+ * @param dataEncoder The encoder to handle the actual data encoding/decoding
+ * @param useColumnFamilies Whether to use RocksDB column families
+ * @param virtualColFamilyId Optional column family ID when column families
are used
+ * @return A RocksDBKeyStateEncoder configured for this spec
+ */
+ def toEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder
}
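
With toEncoder on the trait, callers stay agnostic of the concrete spec; a
sketch (schema and dataEncoder as in the earlier sketches):

  val spec: KeyStateEncoderSpec =
    PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey = 1)
  // Dispatches to PrefixKeyScanStateEncoder without matching on the subtype.
  val keyEncoder = spec.toEncoder(
    dataEncoder, useColumnFamilies = true, Some(1.toShort))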
object KeyStateEncoderSpec {
@@ -347,6 +361,14 @@ case class NoPrefixKeyStateEncoderSpec(keySchema:
StructType) extends KeyStateEn
override def jsonValue: JValue = {
("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec"))
}
+
+ override def toEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
+ new NoPrefixKeyStateEncoder(
+ dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId)
+ }
}
case class PrefixKeyScanStateEncoderSpec(
@@ -356,6 +378,15 @@ case class PrefixKeyScanStateEncoderSpec(
throw
StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString)
}
+ override def toEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
+ new PrefixKeyScanStateEncoder(
+ dataEncoder, keySchema, numColsPrefixKey, useColumnFamilies,
virtualColFamilyId)
+ }
+
override def jsonValue: JValue = {
("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~
("numColsPrefixKey" -> JInt(numColsPrefixKey))
@@ -370,6 +401,14 @@ case class RangeKeyScanStateEncoderSpec(
throw
StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString)
}
+ override def toEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
+ new RangeKeyScanStateEncoder(
+ dataEncoder, keySchema, orderingOrdinals, useColumnFamilies,
virtualColFamilyId)
+ }
+
override def jsonValue: JValue = {
("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~
("orderingOrdinals" -> orderingOrdinals.map(JInt(_)))
@@ -758,6 +797,7 @@ object StateStore extends Logging {
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false): ReadStateStore = {
+ hadoopConf.set(StreamExecution.RUN_ID_KEY,
storeProviderId.queryRunId.toString)
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
@@ -778,6 +818,7 @@ object StateStore extends Logging {
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false): StateStore = {
+ hadoopConf.set(StreamExecution.RUN_ID_KEY,
storeProviderId.queryRunId.toString)
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]