brkyvz commented on code in PR #48401:
URL: https://github.com/apache/spark/pull/48401#discussion_r1852032761


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -563,13 +684,233 @@ class RangeKeyScanStateEncoder(
     writer.getRow()
   }
 
+  def encodePrefixKeyForRangeScan(
+      row: UnsafeRow,
+      avroType: Schema): Array[Byte] = {
+    val record = new GenericData.Record(avroType)
+    var fieldIdx = 0
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+      val value = row.get(idx, field.dataType)
+
+      // Create marker byte buffer
+      val markerBuffer = ByteBuffer.allocate(1)
+      markerBuffer.order(ByteOrder.BIG_ENDIAN)
+
+      if (value == null) {
+        markerBuffer.put(nullValMarker)
+        record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+        record.put(fieldIdx + 1, ByteBuffer.wrap(new Array[Byte](field.dataType.defaultSize)))
+      } else {
+        field.dataType match {
+          case BooleanType =>
+            markerBuffer.put(positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+            val valueBuffer = ByteBuffer.allocate(1)
+            valueBuffer.put(if (value.asInstanceOf[Boolean]) 1.toByte else 0.toByte)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case ByteType =>
+            val byteVal = value.asInstanceOf[Byte]
+            markerBuffer.put(if (byteVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(1)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.put(byteVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case ShortType =>
+            val shortVal = value.asInstanceOf[Short]
+            markerBuffer.put(if (shortVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(2)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putShort(shortVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case IntegerType =>
+            val intVal = value.asInstanceOf[Int]
+            markerBuffer.put(if (intVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(4)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putInt(intVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case LongType =>
+            val longVal = value.asInstanceOf[Long]
+            markerBuffer.put(if (longVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(8)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putLong(longVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case FloatType =>
+            val floatVal = value.asInstanceOf[Float]
+            val rawBits = floatToRawIntBits(floatVal)
+            markerBuffer.put(if ((rawBits & floatSignBitMask) != 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            })
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(4)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            if ((rawBits & floatSignBitMask) != 0) {
+              val updatedVal = rawBits ^ floatFlipBitMask
+              valueBuffer.putFloat(intBitsToFloat(updatedVal))
+            } else {
+              valueBuffer.putFloat(floatVal)
+            }
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case DoubleType =>
+            val doubleVal = value.asInstanceOf[Double]
+            val rawBits = doubleToRawLongBits(doubleVal)
+            markerBuffer.put(if ((rawBits & doubleSignBitMask) != 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            })
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(8)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            if ((rawBits & doubleSignBitMask) != 0) {
+              val updatedVal = rawBits ^ doubleFlipBitMask
+              valueBuffer.putDouble(longBitsToDouble(updatedVal))
+            } else {
+              valueBuffer.putDouble(doubleVal)
+            }
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case _ => throw new UnsupportedOperationException(
+            s"Range scan encoding not supported for data type: 
${field.dataType}")
+        }
+      }
+      fieldIdx += 2
+    }
+
+    out.reset()
+    val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+    val encoder = EncoderFactory.get().binaryEncoder(out, null)
+    writer.write(record, encoder)
+    encoder.flush()
+    out.toByteArray
+  }
+
+  def decodePrefixKeyForRangeScan(

Review Comment:
   ditto on scaladoc please



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -299,13 +401,16 @@ class PrefixKeyScanStateEncoder(
  * @param keySchema - schema of the key to be encoded
  * @param orderingOrdinals - the ordinals for which the range scan is constructed
  * @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
+ *                be defined
  */
 class RangeKeyScanStateEncoder(
     keySchema: StructType,
     orderingOrdinals: Seq[Int],
     useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
+    virtualColFamilyId: Option[Short] = None,
+    avroEnc: Option[AvroEncoder] = None)

Review Comment:
   Instead of avroEnc, I would honestly introduce another interface:
   
   ```scala
   trait Serde {
   
     def encodeToBytes(...)
   
     def decodeToUnsafeRow(...)
     
     def encodePrefixKeyForRangeScan(...)
   
     def decodePrefixKeyForRangeScan(...)
   }
   ```
   and move the logic in there so that you don't have to keep on doing `avroEnc.isDefined` checks for these.
   
   The logic seems pretty similar except for the input data. The AvroStateSerde, or whatever you want to name it, would have the `private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)`.
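   
   To make that concrete, a rough sketch of the shape I have in mind (all names and signatures here are hypothetical, just to illustrate the split):
   
   ```scala
   import org.apache.avro.Schema
   import org.apache.spark.sql.avro.SchemaConverters
   import org.apache.spark.sql.catalyst.expressions.UnsafeRow
   import org.apache.spark.sql.types.StructType
   
   /** Abstracts how a state row is turned into bytes and back. */
   trait StateSerde {
     def encodeToBytes(row: UnsafeRow): Array[Byte]
     def decodeToUnsafeRow(bytes: Array[Byte]): UnsafeRow
     def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte]
     def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow
   }
   
   /**
    * Avro-backed implementation: it owns the Avro schemas and (de)serializers,
    * so the key/value encoders never need an avroEnc.isDefined branch.
    */
   class AvroStateSerde(remainingKeySchema: StructType) extends StateSerde {
     private lazy val remainingKeyAvroType: Schema =
       SchemaConverters.toAvroType(remainingKeySchema)
   
     // Bodies elided; they would carry the Avro-specific logic from this PR.
     def encodeToBytes(row: UnsafeRow): Array[Byte] = ???
     def decodeToUnsafeRow(bytes: Array[Byte]): UnsafeRow = ???
     def encodePrefixKeyForRangeScan(row: UnsafeRow): Array[Byte] = ???
     def decodePrefixKeyForRangeScan(bytes: Array[Byte]): UnsafeRow = ???
   }
   ```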



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -563,13 +684,233 @@ class RangeKeyScanStateEncoder(
     writer.getRow()
   }
 
+  def encodePrefixKeyForRangeScan(

Review Comment:
   Can you add a scaladoc please?
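   
   Something along these lines, for example (a rough sketch based on my reading of the method above; adjust the wording as needed):
   
   ```scala
   /**
    * Encodes the ordering (range-scan) columns of the given key row into an Avro record.
    * Each ordering column is written as two Avro fields: a one-byte marker that
    * distinguishes null, negative and positive values, followed by the value's
    * big-endian bytes (with the raw bits of negative floats/doubles flipped), so that
    * the resulting byte sequence preserves the expected sort order.
    *
    * @param row the key row containing the range-scan columns
    * @param avroType the Avro schema with a (marker, value) field pair per ordering column
    * @return the Avro-encoded bytes for the range-scan prefix key
    */
   ```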



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala:
##########
@@ -492,15 +495,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
-  testWithColumnFamilies("rocksdb range scan multiple non-contiguous ordering 
columns",
+  testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " +
+    "non-contiguous ordering columns",
    TestWithBothChangelogCheckpointingEnabledAndDisabled ) { colFamiliesEnabled =>
     val testSchema: StructType = StructType(
       Seq(
-        StructField("ordering-1", LongType, false),

Review Comment:
   oh, why'd you have to change these? If these are not supported by Avro, do we have any check anywhere to disallow the usage of the Avro encoder?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala:
##########
@@ -37,6 +38,17 @@ case class StateSchemaValidationResult(
     schemaPath: String
 )
 
+// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder
+// in order to serialize from UnsafeRow to a byte array of Avro encoding.

Review Comment:
   Can you please turn this into a proper scaladoc?
   ```scala
   /**
    * ...
    */
   ```
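   
   For example, reusing the wording from the comment above:
   
   ```scala
   /**
    * Avro encoder used by the RocksDBStateStoreProvider and RocksDBStateEncoder
    * to serialize an UnsafeRow to a byte array in Avro encoding.
    */
   ```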



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -74,10 +75,71 @@ private[sql] class RocksDBStateStoreProvider
         isInternal: Boolean = false): Unit = {
       verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, 
isInternal)
       val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
+      // Create cache key using store ID to avoid collisions
+      val avroEncCacheKey = s"${stateStoreId.operatorId}_" +

Review Comment:
   Do we have the stream runId (maybe it's available in the HadoopConf)? We should add runId, otherwise there could be collisions.
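   
   E.g. something along these lines (rough sketch, inside `createColFamilyIfAbsent`; it assumes the run id is propagated to the provider through the Hadoop conf under `StreamExecution.RUN_ID_KEY`, the way other state-store code reads it):
   
   ```scala
   // Hypothetical: prefix the cache key with the streaming query's runId so that
   // different runs of the same query (same operator/partition) never collide.
   val runId = Option(hadoopConf.get(StreamExecution.RUN_ID_KEY)).getOrElse("no-run-id")
   val avroEncCacheKey = s"${runId}_${stateStoreId.operatorId}_" +
     s"${stateStoreId.partitionId}_$colFamilyName"
   ```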



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -74,10 +75,71 @@ private[sql] class RocksDBStateStoreProvider
         isInternal: Boolean = false): Unit = {
       verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, 
isInternal)
       val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
+      // Create cache key using store ID to avoid collisions
+      val avroEncCacheKey = s"${stateStoreId.operatorId}_" +
+        s"${stateStoreId.partitionId}_$colFamilyName"
+
+      // If we have not created the avroEncoder for this column family, create
+      // it, or look in the cache maintained in the RocksDBStateStoreProvider
+      // companion object
+      lazy val avroEnc = stateStoreEncoding match {
+        case "avro" => Some(
+          RocksDBStateStoreProvider.avroEncoderMap.computeIfAbsent(avroEncCacheKey,
+            _ => getAvroEnc(keyStateEncoderSpec, valueSchema))
+        )
+        case "unsaferow" => None
+      }
+
       keyValueEncoderMap.putIfAbsent(colFamilyName,
         (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies,
-          Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(valueSchema,
-          useMultipleValuesPerKey)))
+          Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema,
+          useMultipleValuesPerKey, avroEnc)))
+    }
+    private def getAvroSerializer(schema: StructType): AvroSerializer = {

Review Comment:
   nit: blank line before the method, please



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -128,6 +128,73 @@ trait AlsoTestWithChangelogCheckpointingEnabled
     }
   }
 
+  def testWithEncodingTypes(testName: String, testTags: Tag*)
+                           (testBody: => Any): Unit = {

Review Comment:
   one parameter per line like below please
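   
   i.e. roughly:
   
   ```scala
   def testWithEncodingTypes(
       testName: String,
       testTags: Tag*)(testBody: => Any): Unit = {
     // body unchanged
   }
   ```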



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala:
##########
@@ -58,7 +58,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
   import StateStoreTestsHelper._
 
-  testWithColumnFamilies(s"version encoding",
+  testWithColumnFamiliesAndEncodingTypes(s"version encoding",

Review Comment:
   I wonder if it would be better to just extend these classes and override the SQLConf.
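   
   A rough sketch of what I mean (the suite name and the conf key/value are placeholders for whatever this PR actually introduces):
   
   ```scala
   import org.apache.spark.sql.internal.SQLConf
   
   // Hypothetical: run the whole existing suite again under Avro encoding by
   // flipping the conf in a subclass, instead of touching every test declaration.
   class RocksDBStateStoreAvroSuite extends RocksDBStateStoreSuite {
     override def beforeAll(): Unit = {
       super.beforeAll()
       SQLConf.get.setConfString("spark.sql.streaming.stateStore.encodingFormat", "avro")
     }
   }
   ```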



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -661,19 +1016,27 @@ class RangeKeyScanStateEncoder(
 class NoPrefixKeyStateEncoder(
     keySchema: StructType,
     useColumnFamilies: Boolean = false,
-    virtualColFamilyId: Option[Short] = None)
-  extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
+    virtualColFamilyId: Option[Short] = None,
+    avroEnc: Option[AvroEncoder] = None)

Review Comment:
   ditto on the Serde.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -593,6 +657,9 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_VERSION: Byte = 0
   val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2
 
+  // Add the cache at companion object level so it persists across provider instances
+  private val avroEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, AvroEncoder]()

Review Comment:
   Do we want to leverage LinkedHashMap to limit the size of the cache? 
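   
   Something like this would cap it (a rough sketch; the max size is an arbitrary placeholder, and note that it gives up `ConcurrentHashMap.computeIfAbsent`, so callers would need a small synchronized get-or-create helper instead):
   
   ```scala
   import java.util.{Collections, LinkedHashMap, Map => JMap}
   
   private val maxCachedAvroEncoders = 1000 // placeholder
   
   // Access-ordered LinkedHashMap that evicts the least recently used entry once
   // the cache grows past maxCachedAvroEncoders, wrapped for thread safety.
   private val avroEncoderMap: JMap[String, AvroEncoder] = Collections.synchronizedMap(
     new LinkedHashMap[String, AvroEncoder](16, 0.75f, true) {
       override protected def removeEldestEntry(
           eldest: JMap.Entry[String, AvroEncoder]): Boolean = {
         size() > maxCachedAvroEncoders
       }
     })
   ```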


