brkyvz commented on code in PR #48944:
URL: https://github.com/apache/spark/pull/48944#discussion_r1884513211
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -76,17 +76,16 @@ private[sql] class RocksDBStateStoreProvider
isInternal: Boolean = false): Unit = {
verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName,
isInternal)
val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
- // Create cache key using store ID to avoid collisions
- val avroEncCacheKey =
s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
- s"${stateStoreId.partitionId}_$colFamilyName"
+ val dataEncoderCacheKey =
s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
+ s"${stateStoreId.partitionId}_${colFamilyName}"
- val avroEnc = getAvroEnc(
- stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
+ val dataEncoder = getDataEncoder(
+ stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec,
valueSchema)
keyValueEncoderMap.putIfAbsent(colFamilyName,
- (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
useColumnFamilies,
- Some(newColFamilyId), avroEnc),
RocksDBStateEncoder.getValueEncoder(valueSchema,
- useMultipleValuesPerKey, avroEnc)))
+ (RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec,
useColumnFamilies,
+ Some(newColFamilyId)),
RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema,
+ useMultipleValuesPerKey)))
Review Comment:
nit: can you make these one parameter per line, please? Maybe define the key
and value encoders outside of the `putIfAbsent` first (see the sketch below).
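For example, a rough sketch reusing the names from this diff (not compiled):
```scala
// Build each encoder separately, one parameter per line, then register both.
val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
  dataEncoder,
  keyStateEncoderSpec,
  useColumnFamilies,
  Some(newColFamilyId))
val valueEncoder = RocksDBStateEncoder.getValueEncoder(
  dataEncoder,
  valueSchema,
  useMultipleValuesPerKey)
keyValueEncoderMap.putIfAbsent(colFamilyName, (keyEncoder, valueEncoder))
```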
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -51,93 +51,56 @@ sealed trait RocksDBValueStateEncoder {
def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
}
-abstract class RocksDBKeyStateEncoderBase(
- useColumnFamilies: Boolean,
- virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
- def offsetForColFamilyPrefix: Int =
- if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
-
- val out = new ByteArrayOutputStream
- /**
- * Get Byte Array for the virtual column family id that is used as prefix for
- * key state rows.
- */
- override def getColumnFamilyIdBytes(): Array[Byte] = {
- assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
- " because multiple Column is not supported for this encoder")
- val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
- Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
- encodedBytes
- }
-
- /**
- * Encode and put column family Id as a prefix to a pre-allocated byte array.
- *
- * @param numBytes - size of byte array to be created for storing key row
(without
- * column family prefix)
- * @return Array[Byte] for an array byte to put encoded key bytes
- * Int for a starting offset to put the encoded key bytes
- */
- protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
- val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
- var offset = Platform.BYTE_ARRAY_OFFSET
- if (useColumnFamilies) {
- Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
- offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
- }
- (encodedBytes, offset)
- }
+/**
+ * The DataEncoder can encode UnsafeRows into raw bytes in two ways:
+ * - Using the direct byte layout of the UnsafeRow
+ * - Converting the UnsafeRow into an Avro row, and encoding that
+ * In both cases, the raw bytes written into RocksDB carry headers, footers
+ * and other metadata in addition to the data provided by the caller. That
+ * metadata does not need to be written as Avro or UnsafeRow, but the actual
+ * caller-provided data does.
+ */
+trait DataEncoder {
+ def encodeKey(row: UnsafeRow): Array[Byte]
+ def encodeRemainingKey(row: UnsafeRow): Array[Byte]
+ def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte]
+ def encodeValue(row: UnsafeRow): Array[Byte]
- /**
- * Get starting offset for decoding an encoded key byte array.
- */
- protected def decodeKeyStartOffset: Int = {
- if (useColumnFamilies) {
- Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
- } else Platform.BYTE_ARRAY_OFFSET
- }
+ def decodeKey(bytes: Array[Byte]): UnsafeRow
+ def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow
+ def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow
+ def decodeValue(bytes: Array[Byte]): UnsafeRow
Review Comment:
ditto
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -615,29 +613,56 @@ object RocksDBStateStoreProvider {
private val AVRO_ENCODER_LIFETIME_HOURS = 1L
// Add the cache at companion object level so it persists across provider
instances
- private val avroEncoderMap: NonFateSharingCache[String, AvroEncoder] =
+ private val dataEncoderCache: NonFateSharingCache[String,
RocksDBDataEncoder] =
NonFateSharingCache(
maximumSize = MAX_AVRO_ENCODERS_IN_CACHE,
expireAfterAccessTime = AVRO_ENCODER_LIFETIME_HOURS,
expireAfterAccessTimeUnit = TimeUnit.HOURS
)
- def getAvroEnc(
+ /**
+ * Creates and returns a data encoder for the state store based on the
specified encoding type.
+ * This method handles caching of encoders to improve performance by reusing
encoder instances
+ * when possible.
+ *
+ * The method supports two encoding types:
+ * - Avro: Uses Apache Avro for serialization with schema evolution support
+ * - UnsafeRow: Uses Spark's internal row format for optimal performance
+ *
+ * @param stateStoreEncoding The encoding type to use ("avro" or "unsaferow")
+ * @param encoderCacheKey A unique key for caching the encoder instance,
typically combining
+ * query ID, operator ID, partition ID, and column
family name
+ * @param keyStateEncoderSpec Specification for how to encode keys,
including schema and any
+ * prefix/range scan requirements
+ * @param valueSchema The schema for the values to be encoded
+ * @return A RocksDBDataEncoder instance configured for the specified
encoding type
+ */
+ def getDataEncoder(
stateStoreEncoding: String,
- avroEncCacheKey: String,
+ encoderCacheKey: String,
keyStateEncoderSpec: KeyStateEncoderSpec,
- valueSchema: StructType): Option[AvroEncoder] = {
+ valueSchema: StructType): RocksDBDataEncoder = {
stateStoreEncoding match {
- case "avro" => Some(
- RocksDBStateStoreProvider.avroEncoderMap.get(
- avroEncCacheKey,
- new java.util.concurrent.Callable[AvroEncoder] {
- override def call(): AvroEncoder =
createAvroEnc(keyStateEncoderSpec, valueSchema)
+ case "avro" =>
+ RocksDBStateStoreProvider.dataEncoderCache.get(
+ encoderCacheKey,
+ new java.util.concurrent.Callable[AvroStateEncoder] {
+ override def call(): AvroStateEncoder = {
+ val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
+ new AvroStateEncoder(keyStateEncoderSpec, valueSchema,
avroEncoder)
+ }
+ }
+ )
+ case "unsaferow" =>
+ RocksDBStateStoreProvider.dataEncoderCache.get(
+ encoderCacheKey,
+ new java.util.concurrent.Callable[UnsafeRowDataEncoder] {
+ override def call(): UnsafeRowDataEncoder = {
+ new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema)
+ }
Review Comment:
nit: maybe change the order of the match-case with the cache.get?
```scala
assert(Set("avro", "unsaferow").contains(stateStoreEncoding))
RocksDBStateStoreProvider.dataEncoderCache.get(
encoderCacheKey,
new java.util.concurrent.Callable[DataEncoder] { () =>
if (stateStoreEncoding == "avro") {
...
} else {
...
}
}
```
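For completeness, a fuller version of that sketch could look something like
this (assuming the cache value type stays `RocksDBDataEncoder`, as in the
diff; untested):
```scala
assert(Set("avro", "unsaferow").contains(stateStoreEncoding))
RocksDBStateStoreProvider.dataEncoderCache.get(
  encoderCacheKey,
  new java.util.concurrent.Callable[RocksDBDataEncoder] {
    override def call(): RocksDBDataEncoder = {
      if (stateStoreEncoding == "avro") {
        // Avro: build the Avro serializers once and wrap them in the encoder.
        val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
        new AvroStateEncoder(keyStateEncoderSpec, valueSchema, avroEncoder)
      } else {
        // UnsafeRow: no extra serializers needed.
        new UnsafeRowDataEncoder(keyStateEncoderSpec, valueSchema)
      }
    }
  }
)
```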
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -800,156 +583,509 @@ class RangeKeyScanStateEncoder(
})
record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
- val valueBuffer = ByteBuffer.allocate(8)
- valueBuffer.order(ByteOrder.BIG_ENDIAN)
- if ((rawBits & doubleSignBitMask) != 0) {
- val updatedVal = rawBits ^ doubleFlipBitMask
- valueBuffer.putDouble(longBitsToDouble(updatedVal))
- } else {
- valueBuffer.putDouble(doubleVal)
- }
- record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+ val valueBuffer = ByteBuffer.allocate(8)
+ valueBuffer.order(ByteOrder.BIG_ENDIAN)
+ if ((rawBits & doubleSignBitMask) != 0) {
+ val updatedVal = rawBits ^ doubleFlipBitMask
+ valueBuffer.putDouble(longBitsToDouble(updatedVal))
+ } else {
+ valueBuffer.putDouble(doubleVal)
+ }
+ record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan encoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer,
keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+ avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case RangeKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case _ => throw
unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+ /**
+ * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan
operations.
+ *
+ * This method reverses the encoding process performed by
encodePrefixKeyForRangeScan:
+ * - Reads the marker byte to determine null status or sign
+ * - Reconstructs the original values from big-endian format
+ * - Handles special cases for floating point numbers by reversing bit
manipulations
+ *
+ * The decoding process preserves the original data types and values,
including:
+ * - Null values marked by nullValMarker
+ * - Sign information for numeric types
+ * - Proper restoration of negative floating point values
+ *
+ * @param bytes The Avro-encoded byte array to decode
+ * @param avroType The Avro schema defining the structure for decoding
+ * @return UnsafeRow containing the decoded data
+ * @throws UnsupportedOperationException if a field's data type is not
supported for range
+ * scan decoding
+ */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+ rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case
(fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ rowWriter.write(idx, bytes(0) == 1)
+
+ case ByteType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.get())
+
+ case ShortType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getShort())
+
+ case IntegerType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getInt())
+
+ case LongType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getLong())
+
+ case FloatType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val floatVal = valueBuf.getFloat
+ val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+ rowWriter.write(idx, intBitsToFloat(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getFloat())
+ }
+
+ case DoubleType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val doubleVal = valueBuf.getDouble
+ val updatedVal = doubleToRawLongBits(doubleVal) ^
doubleFlipBitMask
+ rowWriter.write(idx, longBitsToDouble(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getDouble())
+ }
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan decoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+ * @param numBytes - size of byte array to be created for storing key row
(without
+ * column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+object RocksDBStateEncoder extends Logging {
+ def getKeyEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
+ // Return the key state encoder based on the requested type
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(keySchema) =>
+ new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies,
virtualColFamilyId)
+
+ case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+ new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
+ useColumnFamilies, virtualColFamilyId)
+
+ case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+ new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals,
+ useColumnFamilies, virtualColFamilyId)
+
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported key state encoder
spec: " +
+ s"$keyStateEncoderSpec")
+ }
+ }
+
+ def getValueEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ valueSchema: StructType,
+ useMultipleValuesPerKey: Boolean,
+ avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
Review Comment:
ditto
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -51,93 +51,56 @@ sealed trait RocksDBValueStateEncoder {
def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
}
-abstract class RocksDBKeyStateEncoderBase(
- useColumnFamilies: Boolean,
- virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
- def offsetForColFamilyPrefix: Int =
- if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
-
- val out = new ByteArrayOutputStream
- /**
- * Get Byte Array for the virtual column family id that is used as prefix for
- * key state rows.
- */
- override def getColumnFamilyIdBytes(): Array[Byte] = {
- assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
- " because multiple Column is not supported for this encoder")
- val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
- Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
- encodedBytes
- }
-
- /**
- * Encode and put column family Id as a prefix to a pre-allocated byte array.
- *
- * @param numBytes - size of byte array to be created for storing key row
(without
- * column family prefix)
- * @return Array[Byte] for an array byte to put encoded key bytes
- * Int for a starting offset to put the encoded key bytes
- */
- protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
- val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
- var offset = Platform.BYTE_ARRAY_OFFSET
- if (useColumnFamilies) {
- Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
- offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
- }
- (encodedBytes, offset)
- }
+/**
+ * The DataEncoder can encode UnsafeRows into raw bytes in two ways:
+ * - Using the direct byte layout of the UnsafeRow
+ * - Converting the UnsafeRow into an Avro row, and encoding that
+ * In both cases, the raw bytes written into RocksDB carry headers, footers
+ * and other metadata in addition to the data provided by the caller. That
+ * metadata does not need to be written as Avro or UnsafeRow, but the actual
+ * caller-provided data does.
+ */
+trait DataEncoder {
+ def encodeKey(row: UnsafeRow): Array[Byte]
+ def encodeRemainingKey(row: UnsafeRow): Array[Byte]
+ def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte]
+ def encodeValue(row: UnsafeRow): Array[Byte]
Review Comment:
can you actually add docs for these? What additional context would a reader
need for each one? For example, when would `encodeRemainingKey` be used, and
what schema information needs to be passed in for it?
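For example, something along these lines (my reading of the current usage;
the wording is just a suggestion):
```scala
/**
 * Encodes the part of the key that is not covered by the prefix or the
 * range-scan ordering columns, e.g. the suffix columns of a
 * PrefixKeyScanStateEncoderSpec or the non-ordering columns of a
 * RangeKeyScanStateEncoderSpec. The schema of that remainder is derived from
 * the keyStateEncoderSpec passed at construction, so callers only supply the
 * already-projected row.
 */
def encodeRemainingKey(row: UnsafeRow): Array[Byte]
```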
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -800,156 +583,509 @@ class RangeKeyScanStateEncoder(
})
record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
- val valueBuffer = ByteBuffer.allocate(8)
- valueBuffer.order(ByteOrder.BIG_ENDIAN)
- if ((rawBits & doubleSignBitMask) != 0) {
- val updatedVal = rawBits ^ doubleFlipBitMask
- valueBuffer.putDouble(longBitsToDouble(updatedVal))
- } else {
- valueBuffer.putDouble(doubleVal)
- }
- record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+ val valueBuffer = ByteBuffer.allocate(8)
+ valueBuffer.order(ByteOrder.BIG_ENDIAN)
+ if ((rawBits & doubleSignBitMask) != 0) {
+ val updatedVal = rawBits ^ doubleFlipBitMask
+ valueBuffer.putDouble(longBitsToDouble(updatedVal))
+ } else {
+ valueBuffer.putDouble(doubleVal)
+ }
+ record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan encoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer,
keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+ avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case RangeKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case _ => throw
unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+ /**
+ * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan
operations.
+ *
+ * This method reverses the encoding process performed by
encodePrefixKeyForRangeScan:
+ * - Reads the marker byte to determine null status or sign
+ * - Reconstructs the original values from big-endian format
+ * - Handles special cases for floating point numbers by reversing bit
manipulations
+ *
+ * The decoding process preserves the original data types and values,
including:
+ * - Null values marked by nullValMarker
+ * - Sign information for numeric types
+ * - Proper restoration of negative floating point values
+ *
+ * @param bytes The Avro-encoded byte array to decode
+ * @param avroType The Avro schema defining the structure for decoding
+ * @return UnsafeRow containing the decoded data
+ * @throws UnsupportedOperationException if a field's data type is not
supported for range
+ * scan decoding
+ */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+ rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case
(fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ rowWriter.write(idx, bytes(0) == 1)
+
+ case ByteType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.get())
+
+ case ShortType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getShort())
+
+ case IntegerType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getInt())
+
+ case LongType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getLong())
+
+ case FloatType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val floatVal = valueBuf.getFloat
+ val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+ rowWriter.write(idx, intBitsToFloat(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getFloat())
+ }
+
+ case DoubleType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val doubleVal = valueBuf.getDouble
+ val updatedVal = doubleToRawLongBits(doubleVal) ^
doubleFlipBitMask
+ rowWriter.write(idx, longBitsToDouble(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getDouble())
+ }
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan decoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+ * @param numBytes - size of byte array to be created for storing key row
(without
+ * column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+object RocksDBStateEncoder extends Logging {
Review Comment:
Can you add scaladoc for all the methods below?
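For instance, for `getKeyEncoder` (wording is only a suggestion):
```scala
/**
 * Returns the RocksDBKeyStateEncoder matching the given KeyStateEncoderSpec
 * (no-prefix, prefix-scan or range-scan), delegating the actual byte
 * encoding of the key data to the supplied RocksDBDataEncoder.
 *
 * @param dataEncoder encoder for the raw key/value bytes (UnsafeRow or Avro)
 * @param keyStateEncoderSpec spec describing how keys will be scanned
 * @param useColumnFamilies whether a virtual column family prefix is added
 * @param virtualColFamilyId id used for that prefix, when enabled
 */
```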
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -800,156 +583,509 @@ class RangeKeyScanStateEncoder(
})
record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
- val valueBuffer = ByteBuffer.allocate(8)
- valueBuffer.order(ByteOrder.BIG_ENDIAN)
- if ((rawBits & doubleSignBitMask) != 0) {
- val updatedVal = rawBits ^ doubleFlipBitMask
- valueBuffer.putDouble(longBitsToDouble(updatedVal))
- } else {
- valueBuffer.putDouble(doubleVal)
- }
- record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+ val valueBuffer = ByteBuffer.allocate(8)
+ valueBuffer.order(ByteOrder.BIG_ENDIAN)
+ if ((rawBits & doubleSignBitMask) != 0) {
+ val updatedVal = rawBits ^ doubleFlipBitMask
+ valueBuffer.putDouble(longBitsToDouble(updatedVal))
+ } else {
+ valueBuffer.putDouble(doubleVal)
+ }
+ record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan encoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer,
keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+ avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case RangeKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case _ => throw
unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+ /**
+ * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan
operations.
+ *
+ * This method reverses the encoding process performed by
encodePrefixKeyForRangeScan:
+ * - Reads the marker byte to determine null status or sign
+ * - Reconstructs the original values from big-endian format
+ * - Handles special cases for floating point numbers by reversing bit
manipulations
+ *
+ * The decoding process preserves the original data types and values,
including:
+ * - Null values marked by nullValMarker
+ * - Sign information for numeric types
+ * - Proper restoration of negative floating point values
+ *
+ * @param bytes The Avro-encoded byte array to decode
+ * @param avroType The Avro schema defining the structure for decoding
+ * @return UnsafeRow containing the decoded data
+ * @throws UnsupportedOperationException if a field's data type is not
supported for range
+ * scan decoding
+ */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+ rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case
(fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ rowWriter.write(idx, bytes(0) == 1)
+
+ case ByteType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.get())
+
+ case ShortType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getShort())
+
+ case IntegerType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getInt())
+
+ case LongType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getLong())
+
+ case FloatType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val floatVal = valueBuf.getFloat
+ val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+ rowWriter.write(idx, intBitsToFloat(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getFloat())
+ }
+
+ case DoubleType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val doubleVal = valueBuf.getDouble
+ val updatedVal = doubleToRawLongBits(doubleVal) ^
doubleFlipBitMask
+ rowWriter.write(idx, longBitsToDouble(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getDouble())
+ }
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan decoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+ * @param numBytes - size of byte array to be created for storing key row
(without
+ * column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+object RocksDBStateEncoder extends Logging {
+ def getKeyEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
Review Comment:
you can remove this `avroEnc` parameter now, right?
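i.e. the signature would become (just a sketch):
```scala
def getKeyEncoder(
    dataEncoder: RocksDBDataEncoder,
    keyStateEncoderSpec: KeyStateEncoderSpec,
    useColumnFamilies: Boolean,
    virtualColFamilyId: Option[Short] = None): RocksDBKeyStateEncoder
```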
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -800,156 +583,509 @@ class RangeKeyScanStateEncoder(
})
record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
- val valueBuffer = ByteBuffer.allocate(8)
- valueBuffer.order(ByteOrder.BIG_ENDIAN)
- if ((rawBits & doubleSignBitMask) != 0) {
- val updatedVal = rawBits ^ doubleFlipBitMask
- valueBuffer.putDouble(longBitsToDouble(updatedVal))
- } else {
- valueBuffer.putDouble(doubleVal)
- }
- record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+ val valueBuffer = ByteBuffer.allocate(8)
+ valueBuffer.order(ByteOrder.BIG_ENDIAN)
+ if ((rawBits & doubleSignBitMask) != 0) {
+ val updatedVal = rawBits ^ doubleFlipBitMask
+ valueBuffer.putDouble(longBitsToDouble(updatedVal))
+ } else {
+ valueBuffer.putDouble(doubleVal)
+ }
+ record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan encoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer,
keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+ avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case RangeKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case _ => throw
unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+ /**
+ * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan
operations.
+ *
+ * This method reverses the encoding process performed by
encodePrefixKeyForRangeScan:
+ * - Reads the marker byte to determine null status or sign
+ * - Reconstructs the original values from big-endian format
+ * - Handles special cases for floating point numbers by reversing bit
manipulations
+ *
+ * The decoding process preserves the original data types and values,
including:
+ * - Null values marked by nullValMarker
+ * - Sign information for numeric types
+ * - Proper restoration of negative floating point values
+ *
+ * @param bytes The Avro-encoded byte array to decode
+ * @param avroType The Avro schema defining the structure for decoding
+ * @return UnsafeRow containing the decoded data
+ * @throws UnsupportedOperationException if a field's data type is not
supported for range
+ * scan decoding
+ */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+ rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case
(fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ rowWriter.write(idx, bytes(0) == 1)
+
+ case ByteType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.get())
+
+ case ShortType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getShort())
+
+ case IntegerType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getInt())
+
+ case LongType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getLong())
+
+ case FloatType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val floatVal = valueBuf.getFloat
+ val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+ rowWriter.write(idx, intBitsToFloat(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getFloat())
+ }
+
+ case DoubleType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val doubleVal = valueBuf.getDouble
+ val updatedVal = doubleToRawLongBits(doubleVal) ^
doubleFlipBitMask
+ rowWriter.write(idx, longBitsToDouble(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getDouble())
+ }
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan decoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+ * @param numBytes - size of byte array to be created for storing key row
(without
+ * column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+object RocksDBStateEncoder extends Logging {
+ def getKeyEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
+ // Return the key state encoder based on the requested type
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(keySchema) =>
+ new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies,
virtualColFamilyId)
+
+ case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+ new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
+ useColumnFamilies, virtualColFamilyId)
+
+ case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+ new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals,
+ useColumnFamilies, virtualColFamilyId)
+
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported key state encoder
spec: " +
+ s"$keyStateEncoderSpec")
+ }
+ }
Review Comment:
should the encoder specs expose a method instead?
```
def toEncoder(
    dataEncoder: DataEncoder,
    useColumnFamilies: Boolean,
    virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoderBase
```
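For example, inside `PrefixKeyScanStateEncoderSpec` it could look roughly
like this (names taken from this diff; untested):
```scala
def toEncoder(
    dataEncoder: RocksDBDataEncoder,
    useColumnFamilies: Boolean,
    virtualColFamilyId: Option[Short]): RocksDBKeyStateEncoder = {
  // Each spec knows which concrete key encoder it maps to, so the
  // match in RocksDBStateEncoder.getKeyEncoder goes away.
  new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
    useColumnFamilies, virtualColFamilyId)
}
```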
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -800,156 +583,509 @@ class RangeKeyScanStateEncoder(
})
record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
- val valueBuffer = ByteBuffer.allocate(8)
- valueBuffer.order(ByteOrder.BIG_ENDIAN)
- if ((rawBits & doubleSignBitMask) != 0) {
- val updatedVal = rawBits ^ doubleFlipBitMask
- valueBuffer.putDouble(longBitsToDouble(updatedVal))
- } else {
- valueBuffer.putDouble(doubleVal)
- }
- record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+ val valueBuffer = ByteBuffer.allocate(8)
+ valueBuffer.order(ByteOrder.BIG_ENDIAN)
+ if ((rawBits & doubleSignBitMask) != 0) {
+ val updatedVal = rawBits ^ doubleFlipBitMask
+ valueBuffer.putDouble(longBitsToDouble(updatedVal))
+ } else {
+ valueBuffer.putDouble(doubleVal)
+ }
+ record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan encoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer,
keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+ avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case RangeKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, remainingKeyAvroType,
remainingKeyAvroProjection)
+ case _ => throw
unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+ /**
+ * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan
operations.
+ *
+ * This method reverses the encoding process performed by
encodePrefixKeyForRangeScan:
+ * - Reads the marker byte to determine null status or sign
+ * - Reconstructs the original values from big-endian format
+ * - Handles special cases for floating point numbers by reversing bit
manipulations
+ *
+ * The decoding process preserves the original data types and values,
including:
+ * - Null values marked by nullValMarker
+ * - Sign information for numeric types
+ * - Proper restoration of negative floating point values
+ *
+ * @param bytes The Avro-encoded byte array to decode
+ * @param avroType The Avro schema defining the structure for decoding
+ * @return UnsafeRow containing the decoded data
+ * @throws UnsupportedOperationException if a field's data type is not
supported for range
+ * scan decoding
+ */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+ rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case
(fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ rowWriter.write(idx, bytes(0) == 1)
+
+ case ByteType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.get())
+
+ case ShortType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getShort())
+
+ case IntegerType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getInt())
+
+ case LongType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getLong())
+
+ case FloatType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val floatVal = valueBuf.getFloat
+ val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+ rowWriter.write(idx, intBitsToFloat(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getFloat())
+ }
+
+ case DoubleType =>
+ val bytes = record.get(fieldIdx +
1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val doubleVal = valueBuf.getDouble
+ val updatedVal = doubleToRawLongBits(doubleVal) ^
doubleFlipBitMask
+ rowWriter.write(idx, longBitsToDouble(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getDouble())
+ }
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan decoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+ * @param numBytes - size of byte array to be created for storing key row
(without
+ * column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+object RocksDBStateEncoder extends Logging {
+ def getKeyEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
+ // Return the key state encoder based on the requested type
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(keySchema) =>
+ new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies,
virtualColFamilyId)
+
+ case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+ new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
+ useColumnFamilies, virtualColFamilyId)
+
+ case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+ new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals,
+ useColumnFamilies, virtualColFamilyId)
+
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported key state encoder
spec: " +
+ s"$keyStateEncoderSpec")
+ }
+ }
+
+ def getValueEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ valueSchema: StructType,
+ useMultipleValuesPerKey: Boolean,
+ avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
+ if (useMultipleValuesPerKey) {
+ new MultiValuedStateEncoder(dataEncoder, valueSchema)
+ } else {
+ new SingleValueStateEncoder(dataEncoder, valueSchema)
+ }
+ }
+
+ def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = {
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET,
virtualColFamilyId)
+ encodedBytes
+ }
+}
+
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports prefix scan
+ *
+ * @param dataEncoder - the encoder that handles actual encoding/decoding of
data
+ * @param keySchema - schema of the key to be encoded
+ * @param numColsPrefixKey - number of columns to be used for prefix key
+ * @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this
encoder will
+ * be defined
+ */
+class PrefixKeyScanStateEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keySchema: StructType,
+ numColsPrefixKey: Int,
+ useColumnFamilies: Boolean = false,
+ virtualColFamilyId: Option[Short] = None)
+ extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId)
with Logging {
+
+ private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.take(numColsPrefixKey)
+ }
+
+ private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.drop(numColsPrefixKey)
+ }
+
+ private val prefixKeyProjection: UnsafeProjection = {
+ val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2,
x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithIdx.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ // Prefix Key schema and projection definitions used by the Avro Serializers
+ // and Deserializers
+ private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
+ private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
+ private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+
+ // Remaining Key schema and projection definitions used by the Avro Serializers
+ // and Deserializers
+ private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
+ private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
+ private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
Review Comment:
some of these need to be cleaned up. They're unused and no longer needed
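   For illustration only, a rough sketch of the slimmed-down field section this cleanup seems to imply (assuming the Avro schemas and projections now live entirely behind the RocksDBDataEncoder; the deletion itself is hypothetical, not the PR's final code):

   ```scala
   // Only the projections that encodeKey/decodeKey still use would remain:
   private val prefixKeyProjection: UnsafeProjection = {
     val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
     UnsafeProjection.create(refs)
   }

   private val remainingKeyProjection: UnsafeProjection = {
     val refs = remainingKeyFieldsWithIdx.map(x =>
       BoundReference(x._2, x._1.dataType, x._1.nullable))
     UnsafeProjection.create(refs)
   }

   // prefixKeySchema, prefixKeyAvroType, prefixKeyProj and the remainingKey*
   // Avro counterparts would be deleted, since the data encoder owns the
   // Avro conversion now.
   ```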
##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala:
##########
@@ -41,7 +41,7 @@ case class OutputEvent(
* Test suite base for TransformWithState with TTL support.
*/
abstract class TransformWithStateTTLTest
- extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled
+ extends StreamTest
Review Comment:
are these changes merge conflicts?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -664,12 +689,29 @@ object RocksDBStateStoreProvider {
avroOptions.stableIdPrefixForUnionType,
avroOptions.recursiveFieldMaxDepth)
}
+ /**
Review Comment:
IMHO all of this logic should live inside the AvroDataEncoder
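   For illustration, one possible shape for that refactor (a sketch only; the constructor and helper below are hypothetical, not the PR's actual API):

   ```scala
   // Hypothetical: the Avro-specific setup is derived inside the data encoder
   // itself, so callers only hand over the schemas they already have. This
   // class would implement the PR's RocksDBDataEncoder contract.
   class AvroDataEncoder(
       keyStateEncoderSpec: KeyStateEncoderSpec,
       valueSchema: StructType) {

     // Built lazily from the schemas this encoder already owns, instead of
     // being constructed in RocksDBStateStoreProvider and passed in.
     private lazy val avroEncoder: AvroEncoder =
       buildAvroEncoder(keyStateEncoderSpec, valueSchema)

     private def buildAvroEncoder(
         spec: KeyStateEncoderSpec,
         valueSchema: StructType): AvroEncoder = {
       // ... derive key/value Avro (de)serializers from the schemas ...
       ???
     }

     // encodeKey / encodeValue / decodeKey / decodeValue would then use
     // avroEncoder directly, with no Avro objects leaking out of this class.
   }
   ```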
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -387,17 +386,16 @@ private[sql] class RocksDBStateStoreProvider
defaultColFamilyId =
Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
}
- val colFamilyName = StateStore.DEFAULT_COL_FAMILY_NAME
- // Create cache key using store ID to avoid collisions
- val avroEncCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
- s"${stateStoreId.partitionId}_$colFamilyName"
- val avroEnc = getAvroEnc(
- stateStoreEncoding, avroEncCacheKey, keyStateEncoderSpec, valueSchema)
+ val dataEncoderCacheKey = s"${getRunId(hadoopConf)}_${stateStoreId.operatorId}_" +
+ s"${stateStoreId.partitionId}_${StateStore.DEFAULT_COL_FAMILY_NAME}"
+
+ val dataEncoder = getDataEncoder(
+ stateStoreEncoding, dataEncoderCacheKey, keyStateEncoderSpec, valueSchema)
keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
- (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
- useColumnFamilies, defaultColFamilyId, avroEnc),
- RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey, avroEnc)))
+ (RocksDBStateEncoder.getKeyEncoder(dataEncoder, keyStateEncoderSpec,
+ useColumnFamilies, defaultColFamilyId),
+ RocksDBStateEncoder.getValueEncoder(dataEncoder, valueSchema, useMultipleValuesPerKey)))
Review Comment:
ditto
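   For illustration, one possible shape for this call site, with the encoders pulled out into named vals first (a sketch only, not the PR's final code):

   ```scala
   val keyEncoder = RocksDBStateEncoder.getKeyEncoder(
     dataEncoder,
     keyStateEncoderSpec,
     useColumnFamilies,
     defaultColFamilyId)
   val valueEncoder = RocksDBStateEncoder.getValueEncoder(
     dataEncoder,
     valueSchema,
     useMultipleValuesPerKey)
   keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, (keyEncoder, valueEncoder))
   ```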
##########
sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateTTLTest.scala:
##########
@@ -41,7 +41,7 @@ case class OutputEvent(
* Test suite base for TransformWithState with TTL support.
*/
abstract class TransformWithStateTTLTest
- extends StreamTest with AlsoTestWithChangelogCheckpointingEnabled
+ extends StreamTest
Review Comment:
I think the merge accidentally removed this one
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -800,156 +583,509 @@ class RangeKeyScanStateEncoder(
})
record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
- val valueBuffer = ByteBuffer.allocate(8)
- valueBuffer.order(ByteOrder.BIG_ENDIAN)
- if ((rawBits & doubleSignBitMask) != 0) {
- val updatedVal = rawBits ^ doubleFlipBitMask
- valueBuffer.putDouble(longBitsToDouble(updatedVal))
- } else {
- valueBuffer.putDouble(doubleVal)
- }
- record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+ val valueBuffer = ByteBuffer.allocate(8)
+ valueBuffer.order(ByteOrder.BIG_ENDIAN)
+ if ((rawBits & doubleSignBitMask) != 0) {
+ val updatedVal = rawBits ^ doubleFlipBitMask
+ valueBuffer.putDouble(longBitsToDouble(updatedVal))
+ } else {
+ valueBuffer.putDouble(doubleVal)
+ }
+ record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan encoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ out.reset()
+ val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+ val encoder = EncoderFactory.get().binaryEncoder(out, null)
+ writer.write(record, encoder)
+ encoder.flush()
+ out.toByteArray
+ }
+
+ override def encodeValue(row: UnsafeRow): Array[Byte] =
+ encodeUnsafeRowToAvro(row, avroEncoder.valueSerializer, valueAvroType, out)
+
+ override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(_) =>
+ decodeFromAvroToUnsafeRow(bytes, avroEncoder.keyDeserializer, keyAvroType, keyProj)
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, prefixKeyAvroType, prefixKeyProj)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
+ }
+ }
+
+
+ override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+ keyStateEncoderSpec match {
+ case PrefixKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(bytes,
+ avroEncoder.suffixKeyDeserializer.get, remainingKeyAvroType, remainingKeyAvroProjection)
+ case RangeKeyScanStateEncoderSpec(_, _) =>
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.keyDeserializer, remainingKeyAvroType, remainingKeyAvroProjection)
+ case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+ }
+ }
+
+ /**
+ * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan operations.
+ *
+ * This method reverses the encoding process performed by encodePrefixKeyForRangeScan:
+ * - Reads the marker byte to determine null status or sign
+ * - Reconstructs the original values from big-endian format
+ * - Handles special cases for floating point numbers by reversing bit manipulations
+ *
+ * The decoding process preserves the original data types and values, including:
+ * - Null values marked by nullValMarker
+ * - Sign information for numeric types
+ * - Proper restoration of negative floating point values
+ *
+ * @param bytes The Avro-encoded byte array to decode
+ * @param avroType The Avro schema defining the structure for decoding
+ * @return UnsafeRow containing the decoded data
+ * @throws UnsupportedOperationException if a field's data type is not supported for range
+ * scan decoding
+ */
+ override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+ val reader = new GenericDatumReader[GenericRecord](rangeScanAvroType)
+ val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
+ val record = reader.read(null, decoder)
+
+ val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
+ rowWriter.resetRowWriter()
+
+ var fieldIdx = 0
+ rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+ val field = fieldWithOrdinal._1
+
+ val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
+ val markerBuf = ByteBuffer.wrap(markerBytes)
+ markerBuf.order(ByteOrder.BIG_ENDIAN)
+ val marker = markerBuf.get()
+
+ if (marker == nullValMarker) {
+ rowWriter.setNullAt(idx)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ rowWriter.write(idx, bytes(0) == 1)
+
+ case ByteType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.get())
+
+ case ShortType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getShort())
+
+ case IntegerType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getInt())
+
+ case LongType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ rowWriter.write(idx, valueBuf.getLong())
+
+ case FloatType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val floatVal = valueBuf.getFloat
+ val updatedVal = floatToRawIntBits(floatVal) ^ floatFlipBitMask
+ rowWriter.write(idx, intBitsToFloat(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getFloat())
+ }
+
+ case DoubleType =>
+ val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
+ val valueBuf = ByteBuffer.wrap(bytes)
+ valueBuf.order(ByteOrder.BIG_ENDIAN)
+ if (marker == negativeValMarker) {
+ val doubleVal = valueBuf.getDouble
+ val updatedVal = doubleToRawLongBits(doubleVal) ^ doubleFlipBitMask
+ rowWriter.write(idx, longBitsToDouble(updatedVal))
+ } else {
+ rowWriter.write(idx, valueBuf.getDouble())
+ }
+
+ case _ => throw new UnsupportedOperationException(
+ s"Range scan decoding not supported for data type:
${field.dataType}")
+ }
+ }
+ fieldIdx += 2
+ }
+
+ rowWriter.getRow()
+ }
+
+ override def decodeValue(bytes: Array[Byte]): UnsafeRow =
+ decodeFromAvroToUnsafeRow(
+ bytes, avroEncoder.valueDeserializer, valueAvroType, valueProj)
+}
+
+abstract class RocksDBKeyStateEncoderBase(
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None) extends RocksDBKeyStateEncoder {
+ def offsetForColFamilyPrefix: Int =
+ if (useColumnFamilies) VIRTUAL_COL_FAMILY_PREFIX_BYTES else 0
+
+ val out = new ByteArrayOutputStream
+ /**
+ * Get Byte Array for the virtual column family id that is used as prefix for
+ * key state rows.
+ */
+ override def getColumnFamilyIdBytes(): Array[Byte] = {
+ assert(useColumnFamilies, "Cannot return virtual Column Family Id Bytes" +
+ " because multiple Column is not supported for this encoder")
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId.get)
+ encodedBytes
+ }
+
+ /**
+ * Encode and put column family Id as a prefix to a pre-allocated byte array.
+ *
+ * @param numBytes - size of byte array to be created for storing key row (without
+ * column family prefix)
+ * @return Array[Byte] for an array byte to put encoded key bytes
+ * Int for a starting offset to put the encoded key bytes
+ */
+ protected def encodeColumnFamilyPrefix(numBytes: Int): (Array[Byte], Int) = {
+ val encodedBytes = new Array[Byte](numBytes + offsetForColFamilyPrefix)
+ var offset = Platform.BYTE_ARRAY_OFFSET
+ if (useColumnFamilies) {
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId.get)
+ offset = Platform.BYTE_ARRAY_OFFSET + offsetForColFamilyPrefix
+ }
+ (encodedBytes, offset)
+ }
+
+ /**
+ * Get starting offset for decoding an encoded key byte array.
+ */
+ protected def decodeKeyStartOffset: Int = {
+ if (useColumnFamilies) {
+ Platform.BYTE_ARRAY_OFFSET + VIRTUAL_COL_FAMILY_PREFIX_BYTES
+ } else Platform.BYTE_ARRAY_OFFSET
+ }
+}
+
+object RocksDBStateEncoder extends Logging {
+ def getKeyEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keyStateEncoderSpec: KeyStateEncoderSpec,
+ useColumnFamilies: Boolean,
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None): RocksDBKeyStateEncoder = {
+ // Return the key state encoder based on the requested type
+ keyStateEncoderSpec match {
+ case NoPrefixKeyStateEncoderSpec(keySchema) =>
+ new NoPrefixKeyStateEncoder(dataEncoder, keySchema, useColumnFamilies, virtualColFamilyId)
+
+ case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+ new PrefixKeyScanStateEncoder(dataEncoder, keySchema, numColsPrefixKey,
+ useColumnFamilies, virtualColFamilyId)
+
+ case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+ new RangeKeyScanStateEncoder(dataEncoder, keySchema, orderingOrdinals,
+ useColumnFamilies, virtualColFamilyId)
+
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported key state encoder spec: " +
+ s"$keyStateEncoderSpec")
+ }
+ }
+
+ def getValueEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ valueSchema: StructType,
+ useMultipleValuesPerKey: Boolean,
+ avroEnc: Option[AvroEncoder] = None): RocksDBValueStateEncoder = {
+ if (useMultipleValuesPerKey) {
+ new MultiValuedStateEncoder(dataEncoder, valueSchema)
+ } else {
+ new SingleValueStateEncoder(dataEncoder, valueSchema)
+ }
+ }
+
+ def getColumnFamilyIdBytes(virtualColFamilyId: Short): Array[Byte] = {
+ val encodedBytes = new Array[Byte](VIRTUAL_COL_FAMILY_PREFIX_BYTES)
+ Platform.putShort(encodedBytes, Platform.BYTE_ARRAY_OFFSET, virtualColFamilyId)
+ encodedBytes
+ }
+}
+
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports prefix scan
+ *
+ * @param dataEncoder - the encoder that handles actual encoding/decoding of data
+ * @param keySchema - schema of the key to be encoded
+ * @param numColsPrefixKey - number of columns to be used for prefix key
+ * @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
+ * be defined
+ */
+class PrefixKeyScanStateEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keySchema: StructType,
+ numColsPrefixKey: Int,
+ useColumnFamilies: Boolean = false,
+ virtualColFamilyId: Option[Short] = None)
+ extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging {
+
+ private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.take(numColsPrefixKey)
+ }
+
+ private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+ keySchema.zipWithIndex.drop(numColsPrefixKey)
+ }
+
+ private val prefixKeyProjection: UnsafeProjection = {
+ val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithIdx.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
+
+ // Prefix Key schema and projection definitions used by the Avro Serializers
+ // and Deserializers
+ private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
+ private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
+ private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+
+ // Remaining Key schema and projection definitions used by the Avro Serializers
+ // and Deserializers
+ private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
+ private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
+ private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+
+ // This is quite simple to do - just bind sequentially, as we don't change
the order.
+ private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+
+ // Reusable objects
+ private val joinedRowOnKey = new JoinedRow()
+
+ override def encodeKey(row: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = dataEncoder.encodeKey(extractPrefixKey(row))
+ val remainingEncoded = dataEncoder.encodeRemainingKey(remainingKeyProjection(row))
+
+ val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
+ prefixKeyEncoded.length + remainingEncoded.length + 4
+ )
+
+ Platform.putInt(encodedBytes, startingOffset, prefixKeyEncoded.length)
+ Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, startingOffset + 4, prefixKeyEncoded.length)
+ // NOTE: We don't put the length of remainingEncoded as we can calculate later
+ // on deserialization.
+ Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+ encodedBytes, startingOffset + 4 + prefixKeyEncoded.length,
+ remainingEncoded.length)
+
+ encodedBytes
+ }
+
+ override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+ val prefixKeyEncodedLen = Platform.getInt(keyBytes, decodeKeyStartOffset)
+ val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
+ Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4,
+ prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
+
+ // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
+ val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen -
+ offsetForColFamilyPrefix
+
+ val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+ Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen,
+ remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen)
+
+ val prefixKeyDecoded = dataEncoder.decodeKey(
+ prefixKeyEncoded)
+ val remainingKeyDecoded = dataEncoder.decodeRemainingKey(remainingKeyEncoded)
+
+ restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
+ }
+
+ private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+ prefixKeyProjection(key)
+ }
+
+ override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+ val prefixKeyEncoded = dataEncoder.encodeKey(prefixKey)
+ val (prefix, startingOffset) = encodeColumnFamilyPrefix(
+ prefixKeyEncoded.length + 4
+ )
+
+ Platform.putInt(prefix, startingOffset, prefixKeyEncoded.length)
+ Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefix,
+ startingOffset + 4, prefixKeyEncoded.length)
+ prefix
+ }
+
+ override def supportPrefixKeyScan: Boolean = true
+}
+
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size fields
+ *
+ * To encode a row for range scan, we first project the orderingOrdinals from the original
+ * UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's fields in BIG_ENDIAN
+ * to allow for scanning keys in sorted order using the byte-wise comparison method that
+ * RocksDB uses.
+ *
+ * Then, for the rest of the fields, we project those into another UnsafeRow.
+ * We then effectively join these two UnsafeRows together, and finally take those bytes
+ * to get the resulting row.
+ *
+ * We cannot support variable sized fields in the range scan because the UnsafeRow format
+ * stores variable sized fields as offset and length pointers to the actual values,
+ * thereby changing the required ordering.
+ *
+ * Note that we also support "null" values being passed for these fixed size fields. We prepend
+ * a single byte to indicate whether the column value is null or not. We cannot change the
+ * nullability on the UnsafeRow itself as the expected ordering would change if non-first
+ * columns are marked as null. If the first col is null, those entries will appear last in
+ * the iterator. If non-first columns are null, ordering based on the previous columns will
+ * still be honored. For rows with null column values, ordering for subsequent columns
+ * will also be maintained within those sets of rows. We use the same byte to also encode whether
+ * the value is negative or not. For negative float/double values, we flip all the bits to ensure
+ * the right lexicographical ordering. For the rationale around this, please check the link
+ * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
+ *
+ * @param dataEncoder - the encoder that handles the actual encoding/decoding of data
+ * @param keySchema - schema of the key to be encoded
+ * @param orderingOrdinals - the ordinals for which the range scan is constructed
+ * @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
+ * be defined
+ */
+class RangeKeyScanStateEncoder(
+ dataEncoder: RocksDBDataEncoder,
+ keySchema: StructType,
+ orderingOrdinals: Seq[Int],
+ useColumnFamilies: Boolean = false,
+ virtualColFamilyId: Option[Short] = None)
+ extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging {
- case _ => throw new UnsupportedOperationException(
- s"Range scan encoding not supported for data type:
${field.dataType}")
- }
- }
- fieldIdx += 2
+ private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+ orderingOrdinals.map { ordinal =>
+ val field = keySchema(ordinal)
+ (field, ordinal)
}
-
- out.reset()
- val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
- val encoder = EncoderFactory.get().binaryEncoder(out, null)
- writer.write(record, encoder)
- encoder.flush()
- out.toByteArray
}
- /**
- * Decodes an Avro-encoded byte array back into an UnsafeRow for range scan operations.
- *
- * This method reverses the encoding process performed by encodePrefixKeyForRangeScan:
- * - Reads the marker byte to determine null status or sign
- * - Reconstructs the original values from big-endian format
- * - Handles special cases for floating point numbers by reversing bit manipulations
- *
- * The decoding process preserves the original data types and values, including:
- * - Null values marked by nullValMarker
- * - Sign information for numeric types
- * - Proper restoration of negative floating point values
- *
- * @param bytes The Avro-encoded byte array to decode
- * @param avroType The Avro schema defining the structure for decoding
- * @return UnsafeRow containing the decoded data
- * @throws UnsupportedOperationException if a field's data type is not supported for range
- * scan decoding
- */
- def decodePrefixKeyForRangeScan(
- bytes: Array[Byte],
- avroType: Schema): UnsafeRow = {
+ private def isFixedSize(dataType: DataType): Boolean = dataType match {
+ case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
+ _: FloatType | _: DoubleType => true
+ case _ => false
+ }
- val reader = new GenericDatumReader[GenericRecord](avroType)
- val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length,
null)
- val record = reader.read(null, decoder)
+ // verify that only fixed sized columns are used for ordering
+ rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
+ if (!isFixedSize(field.dataType)) {
+ // NullType is technically fixed size, but not supported for ordering
+ if (field.dataType == NullType) {
+ throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString)
+ } else {
+ throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString)
+ }
+ }
+ }
- val rowWriter = new UnsafeRowWriter(rangeScanKeyFieldsWithOrdinal.length)
- rowWriter.resetRowWriter()
+ private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
+ 0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
+ val field = keySchema(ordinal)
+ (field, ordinal)
+ }
+ }
- var fieldIdx = 0
- rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
- val field = fieldWithOrdinal._1
+ private val rangeScanKeyProjection: UnsafeProjection = {
+ val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
- val markerBytes = record.get(fieldIdx).asInstanceOf[ByteBuffer].array()
- val markerBuf = ByteBuffer.wrap(markerBytes)
- markerBuf.order(ByteOrder.BIG_ENDIAN)
- val marker = markerBuf.get()
+ private val remainingKeyProjection: UnsafeProjection = {
+ val refs = remainingKeyFieldsWithOrdinal.map(x =>
+ BoundReference(x._2, x._1.dataType, x._1.nullable))
+ UnsafeProjection.create(refs)
+ }
- if (marker == nullValMarker) {
- rowWriter.setNullAt(idx)
+ // The original schema that we might get could be:
+ // [foo, bar, baz, buzz]
+ // We might order by bar and buzz, leading to:
+ // [bar, buzz, foo, baz]
+ // We need to create a projection that sends, for example, the buzz at index 1 to index
+ // 3. Thus, for every record in the original schema, we compute where it would be in
+ // the joined row and create a projection based on that.
+ private val restoreKeyProjection: UnsafeProjection = {
+ val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
+ val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal)) {
+ orderingOrdinals.indexOf(originalOrdinal)
} else {
- field.dataType match {
- case BooleanType =>
- val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
- rowWriter.write(idx, bytes(0) == 1)
+ orderingOrdinals.length +
+ remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
+ }
- case ByteType =>
- val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- rowWriter.write(idx, valueBuf.get())
+ BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
+ }
+ UnsafeProjection.create(refs)
+ }
- case ShortType =>
- val bytes = record.get(fieldIdx + 1).asInstanceOf[ByteBuffer].array()
- val valueBuf = ByteBuffer.wrap(bytes)
- valueBuf.order(ByteOrder.BIG_ENDIAN)
- rowWriter.write(idx, valueBuf.getShort())
+ private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+ StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))
Review Comment:
these are also no longer used. Please clean them up
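   As an aside for readers following the sign-handling discussion in the RangeKeyScanStateEncoder scaladoc above, here is a small standalone sketch (not the PR's code; the marker bytes and helper names below are made up) of why negative doubles get all their bits flipped before being written out big-endian:

   ```scala
   import java.lang.Double.doubleToRawLongBits
   import java.nio.ByteBuffer

   object RangeScanOrderingSketch {
     // Illustrative marker bytes, not the PR's actual constants.
     private val negativeMarker: Byte = 0x00
     private val positiveMarker: Byte = 0x01

     // One sign-marker byte, then the 8 raw bits big-endian; every bit is
     // flipped for negative values so that unsigned lexicographic byte order
     // (what RocksDB's default comparator uses) matches numeric order.
     def encode(d: Double): Array[Byte] = {
       val raw = doubleToRawLongBits(d)
       val buf = ByteBuffer.allocate(9) // ByteBuffer is big-endian by default
       if ((raw & Long.MinValue) != 0) { // sign bit set => negative (incl. -0.0)
         buf.put(negativeMarker)
         buf.putLong(raw ^ -1L) // flip every bit
       } else {
         buf.put(positiveMarker)
         buf.putLong(raw)
       }
       buf.array()
     }

     // Unsigned lexicographic "less than" over equal-length byte arrays.
     private def unsignedLess(a: Array[Byte], b: Array[Byte]): Boolean =
       a.zip(b).map { case (x, y) => (x & 0xFF) - (y & 0xFF) }
         .find(_ != 0).exists(_ < 0)

     def main(args: Array[String]): Unit = {
       val ascending = Seq(-10.5, -0.25, 0.0, 0.25, 10.5)
       val byByteOrder = ascending.reverse.sortWith((a, b) => unsignedLess(encode(a), encode(b)))
       assert(byByteOrder == ascending) // byte order agrees with numeric order
     }
   }
   ```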
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]