ericm-db commented on code in PR #48944:
URL: https://github.com/apache/spark/pull/48944#discussion_r1887871951


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -219,469 +222,402 @@ object RocksDBStateEncoder extends Logging {
   }
 }
 
-/**
- * RocksDB Key Encoder for UnsafeRow that supports prefix scan
- *
- * @param keySchema - schema of the key to be encoded
- * @param numColsPrefixKey - number of columns to be used for prefix key
- * @param useColumnFamilies - if column family is enabled for this encoder
- * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
- *                be defined
- */
-class PrefixKeyScanStateEncoder(
-    keySchema: StructType,
-    numColsPrefixKey: Int,
-    useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None,
-    avroEnc: Option[AvroEncoder] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
-
-  import RocksDBStateEncoder._
-
-  private val usingAvroEncoding = avroEnc.isDefined
-  private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
-    keySchema.zipWithIndex.take(numColsPrefixKey)
-  }
+class UnsafeRowDataEncoder(
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema) {
 
-  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
-    keySchema.zipWithIndex.drop(numColsPrefixKey)
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    encodeUnsafeRow(row)
   }
 
-  private val prefixKeyProjection: UnsafeProjection = {
-    val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
-    UnsafeProjection.create(refs)
+  override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = {
+    encodeUnsafeRow(row)
   }
 
-  private val remainingKeyProjection: UnsafeProjection = {
-    val refs = remainingKeyFieldsWithIdx.map(x =>
-      BoundReference(x._2, x._1.dataType, x._1.nullable))
-    UnsafeProjection.create(refs)
-  }
+  override def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] = {
+    assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec])
+    val rsk = keyStateEncoderSpec.asInstanceOf[RangeKeyScanStateEncoderSpec]
+    val rangeScanKeyFieldsWithOrdinal = rsk.orderingOrdinals.map { ordinal =>
+      val field = rsk.keySchema(ordinal)
+      (field, ordinal)
+    }
+    val writer = new UnsafeRowWriter(rsk.orderingOrdinals.length)
+    writer.resetRowWriter()
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+      val value = row.get(idx, field.dataType)
+      // Note that we cannot allocate a smaller buffer here even if the value is null
+      // because the effective byte array is considered variable size and needs to have
+      // the same size across all rows for the ordering to work as expected.
+      val bbuf = ByteBuffer.allocate(field.dataType.defaultSize + 1)
+      bbuf.order(ByteOrder.BIG_ENDIAN)
+      if (value == null) {
+        bbuf.put(nullValMarker)
+        writer.write(idx, bbuf.array())
+      } else {
+        field.dataType match {
+          case BooleanType =>
+          case ByteType =>
+            val byteVal = value.asInstanceOf[Byte]
+            val signCol = if (byteVal < 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            }
+            bbuf.put(signCol)
+            bbuf.put(byteVal)
+            writer.write(idx, bbuf.array())
 
-  // Prefix Key schema and projection definitions used by the Avro Serializers
-  // and Deserializers
-  private val prefixKeySchema = StructType(keySchema.take(numColsPrefixKey))
-  private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
-  private val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+          case ShortType =>
+            val shortVal = value.asInstanceOf[Short]
+            val signCol = if (shortVal < 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            }
+            bbuf.put(signCol)
+            bbuf.putShort(shortVal)
+            writer.write(idx, bbuf.array())
 
-  // Remaining Key schema and projection definitions used by the Avro Serializers
-  // and Deserializers
-  private val remainingKeySchema = StructType(keySchema.drop(numColsPrefixKey))
-  private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
-  private val remainingKeyProj = UnsafeProjection.create(remainingKeySchema)
+          case IntegerType =>
+            val intVal = value.asInstanceOf[Int]
+            val signCol = if (intVal < 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            }
+            bbuf.put(signCol)
+            bbuf.putInt(intVal)
+            writer.write(idx, bbuf.array())
 
-  // This is quite simple to do - just bind sequentially, as we don't change the order.
-  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+          case LongType =>
+            val longVal = value.asInstanceOf[Long]
+            val signCol = if (longVal < 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            }
+            bbuf.put(signCol)
+            bbuf.putLong(longVal)
+            writer.write(idx, bbuf.array())
 
-  // Reusable objects
-  private val joinedRowOnKey = new JoinedRow()
+          case FloatType =>
+            val floatVal = value.asInstanceOf[Float]
+            val rawBits = floatToRawIntBits(floatVal)
+            // perform sign comparison using bit manipulation to ensure NaN values are handled
+            // correctly
+            if ((rawBits & floatSignBitMask) != 0) {
+              // for negative values, we need to flip all the bits to ensure correct ordering
+              val updatedVal = rawBits ^ floatFlipBitMask
+              bbuf.put(negativeValMarker)
+              // convert the bits back to float
+              bbuf.putFloat(intBitsToFloat(updatedVal))
+            } else {
+              bbuf.put(positiveValMarker)
+              bbuf.putFloat(floatVal)
+            }
+            writer.write(idx, bbuf.array())
 
-  override def encodeKey(row: UnsafeRow): Array[Byte] = {
-    val (prefixKeyEncoded, remainingEncoded) = if (usingAvroEncoding) {
-      (
-        encodeUnsafeRowToAvro(
-          extractPrefixKey(row),
-          avroEnc.get.keySerializer,
-          prefixKeyAvroType,
-          out
-        ),
-        encodeUnsafeRowToAvro(
-          remainingKeyProjection(row),
-          avroEnc.get.suffixKeySerializer.get,
-          remainingKeyAvroType,
-          out
-        )
-      )
-    } else {
-      (encodeUnsafeRow(extractPrefixKey(row)), encodeUnsafeRow(remainingKeyProjection(row)))
+          case DoubleType =>
+            val doubleVal = value.asInstanceOf[Double]
+            val rawBits = doubleToRawLongBits(doubleVal)
+            // perform sign comparison using bit manipulation to ensure NaN values are handled
+            // correctly
+            if ((rawBits & doubleSignBitMask) != 0) {
+              // for negative values, we need to flip all the bits to ensure correct ordering
+              val updatedVal = rawBits ^ doubleFlipBitMask
+              bbuf.put(negativeValMarker)
+              // convert the bits back to double
+              bbuf.putDouble(longBitsToDouble(updatedVal))
+            } else {
+              bbuf.put(positiveValMarker)
+              bbuf.putDouble(doubleVal)
+            }
+            writer.write(idx, bbuf.array())
+        }
+      }
     }
-    val (encodedBytes, startingOffset) = encodeColumnFamilyPrefix(
-      prefixKeyEncoded.length + remainingEncoded.length + 4
-    )
+    encodeUnsafeRow(writer.getRow())
+  }
 
-    Platform.putInt(encodedBytes, startingOffset, prefixKeyEncoded.length)
-    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
-      encodedBytes, startingOffset + 4, prefixKeyEncoded.length)
-    // NOTE: We don't put the length of remainingEncoded as we can calculate later
-    // on deserialization.
-    Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
-      encodedBytes, startingOffset + 4 + prefixKeyEncoded.length,
-      remainingEncoded.length)
-
-    encodedBytes
-  }
-
-  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
-    val prefixKeyEncodedLen = Platform.getInt(keyBytes, decodeKeyStartOffset)
-    val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
-    Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4,
-      prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
 
-    // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
-    val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen -
-      offsetForColFamilyPrefix
-
-    val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
-    Platform.copyMemory(keyBytes, decodeKeyStartOffset + 4 + prefixKeyEncodedLen,
-      remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET, remainingKeyEncodedLen)
-
-    val (prefixKeyDecoded, remainingKeyDecoded) = if (usingAvroEncoding) {
-      (
-        decodeFromAvroToUnsafeRow(
-          prefixKeyEncoded,
-          avroEnc.get.keyDeserializer,
-          prefixKeyAvroType,
-          prefixKeyProj
-        ),
-        decodeFromAvroToUnsafeRow(
-          remainingKeyEncoded,
-          avroEnc.get.suffixKeyDeserializer.get,
-          remainingKeyAvroType,
-          remainingKeyProj
-        )
-      )
-    } else {
-      (decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey),
-        decodeToUnsafeRow(remainingKeyEncoded, numFields = keySchema.length - numColsPrefixKey))
+  override def decodeKey(bytes: Array[Byte]): UnsafeRow = {
+    keyStateEncoderSpec match {
+      case NoPrefixKeyStateEncoderSpec(_) =>
+        decodeToUnsafeRow(bytes, reusedKeyRow)
+      case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
+        decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+      case _ => throw unsupportedOperationForKeyStateEncoder("decodeKey")
     }
-
-    restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
   }
 
-  private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
-    prefixKeyProjection(key)
+  override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
+    keyStateEncoderSpec match {
+      case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
+        decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
+      case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) =>
+        decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length)
+      case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
+    }
   }
 
-  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
-    val prefixKeyEncoded = if (usingAvroEncoding) {
-      encodeUnsafeRowToAvro(prefixKey, avroEnc.get.keySerializer, prefixKeyAvroType, out)
-    } else {
-      encodeUnsafeRow(prefixKey)
+  override def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = {
+    assert(keyStateEncoderSpec.isInstanceOf[RangeKeyScanStateEncoderSpec])
+    val rsk = keyStateEncoderSpec.asInstanceOf[RangeKeyScanStateEncoderSpec]
+    val writer = new UnsafeRowWriter(rsk.orderingOrdinals.length)
+    val rangeScanKeyFieldsWithOrdinal = rsk.orderingOrdinals.map { ordinal =>
+      val field = rsk.keySchema(ordinal)
+      (field, ordinal)
     }
-    val (prefix, startingOffset) = encodeColumnFamilyPrefix(
-      prefixKeyEncoded.length + 4
-    )
-
-    Platform.putInt(prefix, startingOffset, prefixKeyEncoded.length)
-    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefix,
-      startingOffset + 4, prefixKeyEncoded.length)
-    prefix
-  }
+    writer.resetRowWriter()
+    val row = decodeToUnsafeRow(bytes, numFields = rsk.orderingOrdinals.length)
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
 
-  override def supportPrefixKeyScan: Boolean = true
-}
+      val value = row.getBinary(idx)
+      val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+      bbuf.order(ByteOrder.BIG_ENDIAN)
+      val isNullOrSignCol = bbuf.get()
+      if (isNullOrSignCol == nullValMarker) {
+        // set the column to null and skip reading the next byte(s)
+        writer.setNullAt(idx)
+      } else {
+        field.dataType match {
+          case BooleanType =>
+          case ByteType =>
+            writer.write(idx, bbuf.get)
 
-/**
- * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size fields
- *
- * To encode a row for range scan, we first project the orderingOrdinals from the original
- * UnsafeRow into another UnsafeRow; we then rewrite that new UnsafeRow's fields in BIG_ENDIAN
- * to allow for scanning keys in sorted order using the byte-wise comparison method that
- * RocksDB uses.
- *
- * Then, for the rest of the fields, we project those into another UnsafeRow.
- * We then effectively join these two UnsafeRows together, and finally take those bytes
- * to get the resulting row.
- *
- * We cannot support variable sized fields in the range scan because the UnsafeRow format
- * stores variable sized fields as offset and length pointers to the actual values,
- * thereby changing the required ordering.
- *
- * Note that we also support "null" values being passed for these fixed size fields. We prepend
- * a single byte to indicate whether the column value is null or not. We cannot change the
- * nullability on the UnsafeRow itself as the expected ordering would change if non-first
- * columns are marked as null. If the first col is null, those entries will appear last in
- * the iterator. If non-first columns are null, ordering based on the previous columns will
- * still be honored. For rows with null column values, ordering for subsequent columns
- * will also be maintained within those set of rows. We use the same byte to also encode whether
- * the value is negative or not. For negative float/double values, we flip all the bits to ensure
- * the right lexicographical ordering. For the rationale around this, please check the link
- * here: https://en.wikipedia.org/wiki/IEEE_754#Design_rationale
- *
- * @param keySchema - schema of the key to be encoded
- * @param orderingOrdinals - the ordinals for which the range scan is constructed
- * @param useColumnFamilies - if column family is enabled for this encoder
- * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
- *                be defined
- */
-class RangeKeyScanStateEncoder(
-    keySchema: StructType,
-    orderingOrdinals: Seq[Int],
-    useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None,
-    avroEnc: Option[AvroEncoder] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) with Logging {
+          case ShortType =>
+            writer.write(idx, bbuf.getShort)
 
-  import RocksDBStateEncoder._
+          case IntegerType =>
+            writer.write(idx, bbuf.getInt)
 
-  private val rangeScanKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
-    orderingOrdinals.map { ordinal =>
-      val field = keySchema(ordinal)
-      (field, ordinal)
-    }
-  }
+          case LongType =>
+            writer.write(idx, bbuf.getLong)
 
-  private def isFixedSize(dataType: DataType): Boolean = dataType match {
-    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: LongType |
-      _: FloatType | _: DoubleType => true
-    case _ => false
-  }
+          case FloatType =>
+            if (isNullOrSignCol == negativeValMarker) {
+              // if the number is negative, get the raw binary bits for the float
+              // and flip the bits back
+              val updatedVal = floatToRawIntBits(bbuf.getFloat) ^ floatFlipBitMask
+              writer.write(idx, intBitsToFloat(updatedVal))
+            } else {
+              writer.write(idx, bbuf.getFloat)
+            }
 
-  // verify that only fixed sized columns are used for ordering
-  rangeScanKeyFieldsWithOrdinal.foreach { case (field, ordinal) =>
-    if (!isFixedSize(field.dataType)) {
-      // NullType is technically fixed size, but not supported for ordering
-      if (field.dataType == NullType) {
-        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, ordinal.toString)
-      } else {
-        throw StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, ordinal.toString)
+          case DoubleType =>
+            if (isNullOrSignCol == negativeValMarker) {
+              // if the number is negative, get the raw binary bits for the double
+              // and flip the bits back
+              val updatedVal = doubleToRawLongBits(bbuf.getDouble) ^ doubleFlipBitMask
+              writer.write(idx, longBitsToDouble(updatedVal))
+            } else {
+              writer.write(idx, bbuf.getDouble)
+            }
+        }
       }
     }
+    writer.getRow()
   }
 
-  private val remainingKeyFieldsWithOrdinal: Seq[(StructField, Int)] = {
-    0.to(keySchema.length - 1).diff(orderingOrdinals).map { ordinal =>
-      val field = keySchema(ordinal)
-      (field, ordinal)
-    }
-  }
+  override def decodeValue(bytes: Array[Byte]): UnsafeRow = decodeToUnsafeRow(bytes, reusedValueRow)
+}
 
-  private val rangeScanKeyProjection: UnsafeProjection = {
-    val refs = rangeScanKeyFieldsWithOrdinal.map(x =>
-      BoundReference(x._2, x._1.dataType, x._1.nullable))
-    UnsafeProjection.create(refs)
-  }
+class AvroStateEncoder(
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    valueSchema: StructType) extends RocksDBDataEncoder(keyStateEncoderSpec, valueSchema)
+    with Logging {
 
-  private val remainingKeyProjection: UnsafeProjection = {
-    val refs = remainingKeyFieldsWithOrdinal.map(x =>
-      BoundReference(x._2, x._1.dataType, x._1.nullable))
-    UnsafeProjection.create(refs)
-  }
+  private val avroEncoder = createAvroEnc(keyStateEncoderSpec, valueSchema)
+  // Avro schema used by the avro encoders
+  private lazy val keyAvroType: Schema = SchemaConverters.toAvroType(keySchema)
+  private lazy val keyProj = UnsafeProjection.create(keySchema)
 
-  // The original schema that we might get could be:
-  //    [foo, bar, baz, buzz]
-  // We might order by bar and buzz, leading to:
-  //    [bar, buzz, foo, baz]
-  // We need to create a projection that sends, for example, the buzz at index 1 to index
-  // 3. Thus, for every record in the original schema, we compute where it would be in
-  // the joined row and create a projection based on that.
-  private val restoreKeyProjection: UnsafeProjection = {
-    val refs = keySchema.zipWithIndex.map { case (field, originalOrdinal) =>
-      val ordinalInJoinedRow = if (orderingOrdinals.contains(originalOrdinal)) {
-          orderingOrdinals.indexOf(originalOrdinal)
-      } else {
-          orderingOrdinals.length +
-            remainingKeyFieldsWithOrdinal.indexWhere(_._2 == originalOrdinal)
-      }
+  private lazy val valueAvroType: Schema = SchemaConverters.toAvroType(valueSchema)
+  private lazy val valueProj = UnsafeProjection.create(valueSchema)
 
-      BoundReference(ordinalInJoinedRow, field.dataType, field.nullable)
-    }
-    UnsafeProjection.create(refs)
+  // Prefix Key schema and projection definitions used by the Avro Serializers
+  // and Deserializers
+  private lazy val prefixKeySchema = keyStateEncoderSpec match {
+    case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+      StructType(keySchema.take(numColsPrefixKey))
+    case _ => throw unsupportedOperationForKeyStateEncoder("prefixKeySchema")
+  }
+  private lazy val prefixKeyAvroType = SchemaConverters.toAvroType(prefixKeySchema)
+  private lazy val prefixKeyProj = UnsafeProjection.create(prefixKeySchema)
+
+  // Range Key schema and projection definitions used by the Avro Serializers and
+  // Deserializers
+  private lazy val rangeScanKeyFieldsWithOrdinal = keyStateEncoderSpec match {
+    case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+      orderingOrdinals.map { ordinal =>
+        val field = keySchema(ordinal)
+        (field, ordinal)
+      }
+    case _ =>
+      throw unsupportedOperationForKeyStateEncoder("rangeScanKey")
   }
 
-  private val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
+  private lazy val rangeScanAvroSchema = StateStoreColumnFamilySchemaUtils.convertForRangeScan(
     StructType(rangeScanKeyFieldsWithOrdinal.map(_._1).toArray))
 
   private lazy val rangeScanAvroType = SchemaConverters.toAvroType(rangeScanAvroSchema)
 
-  private val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema)
+  private lazy val rangeScanAvroProjection = UnsafeProjection.create(rangeScanAvroSchema)
 
-  // Existing remainder key schema stuff
-  private val remainingKeySchema = StructType(
-    0.to(keySchema.length - 1).diff(orderingOrdinals).map(keySchema(_))
-  )
+  // Existing remainder key schema definitions
+  // Remaining Key schema and projection definitions used by the Avro Serializers
+  // and Deserializers
+  private lazy val remainingKeySchema = keyStateEncoderSpec match {
+    case PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) =>
+      StructType(keySchema.drop(numColsPrefixKey))
+    case RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) =>
+      StructType(0.until(keySchema.length).diff(orderingOrdinals).map(keySchema(_)))
+    case _ => throw unsupportedOperationForKeyStateEncoder("remainingKeySchema")
+  }
 
   private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)
 
-  private val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)
+  private lazy val remainingKeyAvroProjection = UnsafeProjection.create(remainingKeySchema)
 
-  // Reusable objects
-  private val joinedRowOnKey = new JoinedRow()
-
-  private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
-    rangeScanKeyProjection(key)
+  private def getAvroSerializer(schema: StructType): AvroSerializer = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    new AvroSerializer(schema, avroType, nullable = false)
   }
 
-  // bit masks used for checking sign or flipping all bits for negative float/double values
-  private val floatFlipBitMask = 0xFFFFFFFF
-  private val floatSignBitMask = 0x80000000
+  private def getAvroDeserializer(schema: StructType): AvroDeserializer = {
+    val avroType = SchemaConverters.toAvroType(schema)
+    val avroOptions = AvroOptions(Map.empty)
+    new AvroDeserializer(avroType, schema,
+      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
+      avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
+  }
 
-  private val doubleFlipBitMask = 0xFFFFFFFFFFFFFFFFL
-  private val doubleSignBitMask = 0x8000000000000000L
+  /**
+   * Creates an AvroEncoder that handles both key and value serialization/deserialization.
+   * This method sets up the complete encoding infrastructure needed for state store operations.
+   *
+   * The encoder handles different key encoding specifications:
+   * - NoPrefixKeyStateEncoderSpec: Simple key encoding without prefix
+   * - PrefixKeyScanStateEncoderSpec: Keys with prefix for efficient scanning
+   * - RangeKeyScanStateEncoderSpec: Keys with ordering requirements for range scans
+   *
+   * For prefix scan cases, it also creates separate encoders for the suffix portion of keys.
+   *
+   * @param keyStateEncoderSpec Specification for how to encode keys
+   * @param valueSchema Schema for the values to be encoded
+   * @return An AvroEncoder containing all necessary serializers and deserializers
+   */
+  private def createAvroEnc(
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      valueSchema: StructType): AvroEncoder = {
+    val valueSerializer = getAvroSerializer(valueSchema)
+    val valueDeserializer = getAvroDeserializer(valueSchema)
+
+    // Get key schema based on encoder spec type
+    val keySchema = keyStateEncoderSpec match {
+      case NoPrefixKeyStateEncoderSpec(schema) =>
+        schema
+      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+        StructType(schema.take(numColsPrefixKey))
+      case RangeKeyScanStateEncoderSpec(schema, orderingOrdinals) =>
+        val remainingSchema = {
+          0.until(schema.length).diff(orderingOrdinals).map { ordinal =>
+            schema(ordinal)
+          }
+        }
+        StructType(remainingSchema)
+    }
 
-  // Byte markers used to identify whether the value is null, negative or positive
-  // To ensure sorted ordering, we use the lowest byte value for negative numbers followed by
-  // positive numbers and then null values.
-  private val negativeValMarker: Byte = 0x00.toByte
-  private val positiveValMarker: Byte = 0x01.toByte
-  private val nullValMarker: Byte = 0x02.toByte
-
-  // Rewrite the unsafe row by replacing fixed size fields with BIG_ENDIAN encoding
-  // using byte arrays.
-  // To handle "null" values, we prepend a byte to the byte array indicating whether the value
-  // is null or not. If the value is null, we write the null byte followed by zero bytes.
-  // If the value is not null, we write the null byte followed by the value.
-  // Note that setting null for the index on the unsafeRow is not feasible as it would change
-  // the sorting order on iteration.
-  // Also note that the same byte is used to indicate whether the value is negative or not.
-  private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
-    val writer = new UnsafeRowWriter(orderingOrdinals.length)
-    writer.resetRowWriter()
-    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
-      val field = fieldWithOrdinal._1
-      val value = row.get(idx, field.dataType)
-      // Note that we cannot allocate a smaller buffer here even if the value is null
-      // because the effective byte array is considered variable size and needs to have
-      // the same size across all rows for the ordering to work as expected.
-      val bbuf = ByteBuffer.allocate(field.dataType.defaultSize + 1)
-      bbuf.order(ByteOrder.BIG_ENDIAN)
-      if (value == null) {
-        bbuf.put(nullValMarker)
-        writer.write(idx, bbuf.array())
-      } else {
-        field.dataType match {
-          case BooleanType =>
-          case ByteType =>
-            val byteVal = value.asInstanceOf[Byte]
-            val signCol = if (byteVal < 0) {
-              negativeValMarker
-            } else {
-              positiveValMarker
-            }
-            bbuf.put(signCol)
-            bbuf.put(byteVal)
-            writer.write(idx, bbuf.array())
+    // Handle suffix key schema for prefix scan case
+    val suffixKeySchema = keyStateEncoderSpec match {
+      case PrefixKeyScanStateEncoderSpec(schema, numColsPrefixKey) =>
+        Some(StructType(schema.drop(numColsPrefixKey)))
+      case _ =>
+        None
+    }
 
-          case ShortType =>
-            val shortVal = value.asInstanceOf[Short]
-            val signCol = if (shortVal < 0) {
-              negativeValMarker
-            } else {
-              positiveValMarker
-            }
-            bbuf.put(signCol)
-            bbuf.putShort(shortVal)
-            writer.write(idx, bbuf.array())
+    val keySerializer = getAvroSerializer(keySchema)
+    val keyDeserializer = getAvroDeserializer(keySchema)
+
+    // Create the AvroEncoder with all components
+    AvroEncoder(
+      keySerializer,
+      keyDeserializer,
+      valueSerializer,
+      valueDeserializer,
+      suffixKeySchema.map(getAvroSerializer),
+      suffixKeySchema.map(getAvroDeserializer)
+    )
+  }
 
-          case IntegerType =>
-            val intVal = value.asInstanceOf[Int]
-            val signCol = if (intVal < 0) {
-              negativeValMarker
-            } else {
-              positiveValMarker
-            }
-            bbuf.put(signCol)
-            bbuf.putInt(intVal)
-            writer.write(idx, bbuf.array())
-
-          case LongType =>
-            val longVal = value.asInstanceOf[Long]
-            val signCol = if (longVal < 0) {
-              negativeValMarker
-            } else {
-              positiveValMarker
-            }
-            bbuf.put(signCol)
-            bbuf.putLong(longVal)
-            writer.write(idx, bbuf.array())
-
-          case FloatType =>
-            val floatVal = value.asInstanceOf[Float]
-            val rawBits = floatToRawIntBits(floatVal)
-            // perform sign comparison using bit manipulation to ensure NaN values are handled
-            // correctly
-            if ((rawBits & floatSignBitMask) != 0) {
-              // for negative values, we need to flip all the bits to ensure correct ordering
-              val updatedVal = rawBits ^ floatFlipBitMask
-              bbuf.put(negativeValMarker)
-              // convert the bits back to float
-              bbuf.putFloat(intBitsToFloat(updatedVal))
-            } else {
-              bbuf.put(positiveValMarker)
-              bbuf.putFloat(floatVal)
-            }
-            writer.write(idx, bbuf.array())
+  /**
+   * This method takes an UnsafeRow, and serializes to a byte array using Avro encoding.
+   */
+  def encodeUnsafeRowToAvro(
+      row: UnsafeRow,
+      avroSerializer: AvroSerializer,
+      valueAvroType: Schema,
+      out: ByteArrayOutputStream): Array[Byte] = {
+    // InternalRow -> Avro.GenericDataRecord
+    val avroData =
+      avroSerializer.serialize(row)
+    out.reset()
+    val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
+    val writer = new GenericDatumWriter[Any](
+      valueAvroType) // Defining Avro writer for this struct type
+    writer.write(avroData, encoder) // Avro.GenericDataRecord -> byte array
+    encoder.flush()
+    out.toByteArray
+  }
 
-          case DoubleType =>
-            val doubleVal = value.asInstanceOf[Double]
-            val rawBits = doubleToRawLongBits(doubleVal)
-            // perform sign comparison using bit manipulation to ensure NaN values are handled
-            // correctly
-            if ((rawBits & doubleSignBitMask) != 0) {
-              // for negative values, we need to flip all the bits to ensure correct ordering
-              val updatedVal = rawBits ^ doubleFlipBitMask
-              bbuf.put(negativeValMarker)
-              // convert the bits back to double
-              bbuf.putDouble(longBitsToDouble(updatedVal))
-            } else {
-              bbuf.put(positiveValMarker)
-              bbuf.putDouble(doubleVal)
-            }
-            writer.write(idx, bbuf.array())
-        }
-      }
+  /**
+   * This method takes a byte array written using Avro encoding, and
+   * deserializes to an UnsafeRow using the Avro deserializer
+   */
+  def decodeFromAvroToUnsafeRow(
+      valueBytes: Array[Byte],
+      avroDeserializer: AvroDeserializer,
+      valueAvroType: Schema,
+      valueProj: UnsafeProjection): UnsafeRow = {
+    if (valueBytes != null) {
+      val reader = new GenericDatumReader[Any](valueAvroType)
+      val decoder = DecoderFactory.get().binaryDecoder(
+        valueBytes, 0, valueBytes.length, null)
+      // bytes -> Avro.GenericDataRecord
+      val genericData = reader.read(null, decoder)
+      // Avro.GenericDataRecord -> InternalRow
+      val internalRow = avroDeserializer.deserialize(
+        genericData).orNull.asInstanceOf[InternalRow]
+      // InternalRow -> UnsafeRow
+      valueProj.apply(internalRow)
+    } else {
+      null
     }
-    writer.getRow()
   }
 
-  // Rewrite the unsafe row by converting back from BIG_ENDIAN byte arrays to the
-  // original data types.
-  // For decode, we extract the byte array from the UnsafeRow, and then read the first byte
-  // to determine if the value is null or not. If the value is null, we set the ordinal on
-  // the UnsafeRow to null. If the value is not null, we read the rest of the bytes to get the
-  // actual value.
-  // For negative float/double values, we need to flip all the bits back to get the original value.
-  private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
-    val writer = new UnsafeRowWriter(orderingOrdinals.length)
-    writer.resetRowWriter()
-    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
-      val field = fieldWithOrdinal._1
-
-      val value = row.getBinary(idx)
-      val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
-      bbuf.order(ByteOrder.BIG_ENDIAN)
-      val isNullOrSignCol = bbuf.get()
-      if (isNullOrSignCol == nullValMarker) {
-        // set the column to null and skip reading the next byte(s)
-        writer.setNullAt(idx)
-      } else {
-        field.dataType match {
-          case BooleanType =>
-          case ByteType =>
-            writer.write(idx, bbuf.get)
-
-          case ShortType =>
-            writer.write(idx, bbuf.getShort)
-
-          case IntegerType =>
-            writer.write(idx, bbuf.getInt)
-
-          case LongType =>
-            writer.write(idx, bbuf.getLong)
+  private val out = new ByteArrayOutputStream
 
-          case FloatType =>
-            if (isNullOrSignCol == negativeValMarker) {
-              // if the number is negative, get the raw binary bits for the float
-              // and flip the bits back
-              val updatedVal = floatToRawIntBits(bbuf.getFloat) ^ floatFlipBitMask
-              writer.write(idx, intBitsToFloat(updatedVal))
-            } else {
-              writer.write(idx, bbuf.getFloat)
-            }
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    keyStateEncoderSpec match {
+      case NoPrefixKeyStateEncoderSpec(_) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, keyAvroType, out)
+      case PrefixKeyScanStateEncoderSpec(_, _) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, prefixKeyAvroType, out)
+      case _ => throw unsupportedOperationForKeyStateEncoder("encodeKey")
+    }
+  }
 
-          case DoubleType =>
-            if (isNullOrSignCol == negativeValMarker) {
-              // if the number is negative, get the raw binary bits for the double
-              // and flip the bits back
-              val updatedVal = doubleToRawLongBits(bbuf.getDouble) ^ doubleFlipBitMask
-              writer.write(idx, longBitsToDouble(updatedVal))
-            } else {
-              writer.write(idx, bbuf.getDouble)
-            }
-        }
-      }
+  override def encodeRemainingKey(row: UnsafeRow): Array[Byte] = {
+    keyStateEncoderSpec match {
+      case PrefixKeyScanStateEncoderSpec(_, _) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.suffixKeySerializer.get, remainingKeyAvroType, out)
+      case RangeKeyScanStateEncoderSpec(_, _) =>
+        encodeUnsafeRowToAvro(row, avroEncoder.keySerializer, remainingKeyAvroType, out)
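
Aside: the two helpers in this hunk, `encodeUnsafeRowToAvro` and `decodeFromAvroToUnsafeRow`, form a round trip. Below is a minimal sketch of that round trip in isolation; the schema and values are invented for illustration, and since the Avro serializer classes are not public API, it is assumed to compile inside Spark's own sql module, as the PR's code does. Each call mirrors one that appears in the hunk above.

```scala
import java.io.ByteArrayOutputStream

import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}

import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroSerializer, SchemaConverters}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

object AvroRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(StructField("id", LongType), StructField("name", StringType)))
    val avroType = SchemaConverters.toAvroType(schema)
    val serializer = new AvroSerializer(schema, avroType, nullable = false)
    val avroOptions = AvroOptions(Map.empty)
    val deserializer = new AvroDeserializer(avroType, schema,
      avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType,
      avroOptions.stableIdPrefixForUnionType, avroOptions.recursiveFieldMaxDepth)
    val proj = UnsafeProjection.create(schema)
    val row = proj(InternalRow(42L, UTF8String.fromString("spark"))).copy()

    // UnsafeRow -> Avro bytes, following encodeUnsafeRowToAvro
    val out = new ByteArrayOutputStream
    val encoder = EncoderFactory.get().directBinaryEncoder(out, null)
    new GenericDatumWriter[Any](avroType).write(serializer.serialize(row), encoder)
    encoder.flush()
    val bytes = out.toByteArray

    // Avro bytes -> UnsafeRow, following decodeFromAvroToUnsafeRow
    val decoder = DecoderFactory.get().binaryDecoder(bytes, 0, bytes.length, null)
    val genericData = new GenericDatumReader[Any](avroType).read(null, decoder)
    val back = proj(deserializer.deserialize(genericData).orNull.asInstanceOf[InternalRow])
    assert(back.getLong(0) == 42L && back.getUTF8String(1).toString == "spark")
  }
}
```

Note that the encoder class shares a single `ByteArrayOutputStream` (`private val out` above) and calls `out.reset()` per serialization, so the buffer is reused rather than reallocated for every row.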

Review Comment:
   No, this is right - RangeKeyScan encoder doesn't use Avro serialization for both parts of the key.
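
To make the ordering point concrete: RocksDB compares keys with a plain unsigned byte-wise comparator, so the ordering columns must be written in a binary form whose lexicographic order matches numeric order, which Avro's binary encoding does not guarantee. Here is a standalone sketch of the sign-marker / bit-flip scheme that `encodePrefixKeyForRangeScan` applies, shown for a single float column; the markers and masks match the diff, everything else is invented for illustration.

```scala
import java.lang.Float.{floatToRawIntBits, intBitsToFloat}
import java.nio.{ByteBuffer, ByteOrder}

object RangeScanOrderingSketch {
  // Markers from the diff: negatives sort first, then positives, then nulls.
  private val negativeValMarker: Byte = 0x00.toByte
  private val positiveValMarker: Byte = 0x01.toByte

  // Encode one float as a marker byte + 4 big-endian bytes, flipping all bits
  // of negative values so that byte order equals numeric order.
  def encode(f: Float): Array[Byte] = {
    val bbuf = ByteBuffer.allocate(5).order(ByteOrder.BIG_ENDIAN)
    val rawBits = floatToRawIntBits(f)
    if ((rawBits & 0x80000000) != 0) {
      bbuf.put(negativeValMarker)
      bbuf.putFloat(intBitsToFloat(rawBits ^ 0xFFFFFFFF)) // flip every bit
    } else {
      bbuf.put(positiveValMarker)
      bbuf.putFloat(f)
    }
    bbuf.array()
  }

  // Unsigned lexicographic comparison, i.e. RocksDB's default comparator.
  def byteCompare(a: Array[Byte], b: Array[Byte]): Int =
    a.zip(b).map { case (x, y) => (x & 0xff) - (y & 0xff) }
      .find(_ != 0).getOrElse(a.length - b.length)

  def main(args: Array[String]): Unit = {
    val sorted = Seq(3.5f, -1.0f, -7.25f, 0.0f, 42.0f)
      .sortWith((x, y) => byteCompare(encode(x), encode(y)) < 0)
    println(sorted) // List(-7.25, -1.0, 0.0, 3.5, 42.0): numeric order preserved
  }
}
```

The 0x00 marker puts every negative before every non-negative, and flipping all bits of a negative value reverses the ordering within the negative range (for example, -7.25 encodes lower than -1.0), which is exactly what a byte-wise range scan needs.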



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

