brkyvz commented on code in PR #48401:
URL: https://github.com/apache/spark/pull/48401#discussion_r1852032761
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -563,13 +684,233 @@ class RangeKeyScanStateEncoder(
writer.getRow()
}
+  def encodePrefixKeyForRangeScan(
+      row: UnsafeRow,
+      avroType: Schema): Array[Byte] = {
+    val record = new GenericData.Record(avroType)
+    var fieldIdx = 0
+    rangeScanKeyFieldsWithOrdinal.zipWithIndex.foreach { case (fieldWithOrdinal, idx) =>
+      val field = fieldWithOrdinal._1
+      val value = row.get(idx, field.dataType)
+
+      // Create marker byte buffer
+      val markerBuffer = ByteBuffer.allocate(1)
+      markerBuffer.order(ByteOrder.BIG_ENDIAN)
+
+      if (value == null) {
+        markerBuffer.put(nullValMarker)
+        record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+        record.put(fieldIdx + 1, ByteBuffer.wrap(new Array[Byte](field.dataType.defaultSize)))
+      } else {
+        field.dataType match {
+          case BooleanType =>
+            markerBuffer.put(positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+            val valueBuffer = ByteBuffer.allocate(1)
+            valueBuffer.put(if (value.asInstanceOf[Boolean]) 1.toByte else 0.toByte)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case ByteType =>
+            val byteVal = value.asInstanceOf[Byte]
+            markerBuffer.put(if (byteVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(1)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.put(byteVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case ShortType =>
+            val shortVal = value.asInstanceOf[Short]
+            markerBuffer.put(if (shortVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(2)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putShort(shortVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case IntegerType =>
+            val intVal = value.asInstanceOf[Int]
+            markerBuffer.put(if (intVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(4)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putInt(intVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case LongType =>
+            val longVal = value.asInstanceOf[Long]
+            markerBuffer.put(if (longVal < 0) negativeValMarker else positiveValMarker)
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(8)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            valueBuffer.putLong(longVal)
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case FloatType =>
+            val floatVal = value.asInstanceOf[Float]
+            val rawBits = floatToRawIntBits(floatVal)
+            markerBuffer.put(if ((rawBits & floatSignBitMask) != 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            })
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(4)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            if ((rawBits & floatSignBitMask) != 0) {
+              val updatedVal = rawBits ^ floatFlipBitMask
+              valueBuffer.putFloat(intBitsToFloat(updatedVal))
+            } else {
+              valueBuffer.putFloat(floatVal)
+            }
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case DoubleType =>
+            val doubleVal = value.asInstanceOf[Double]
+            val rawBits = doubleToRawLongBits(doubleVal)
+            markerBuffer.put(if ((rawBits & doubleSignBitMask) != 0) {
+              negativeValMarker
+            } else {
+              positiveValMarker
+            })
+            record.put(fieldIdx, ByteBuffer.wrap(markerBuffer.array()))
+
+            val valueBuffer = ByteBuffer.allocate(8)
+            valueBuffer.order(ByteOrder.BIG_ENDIAN)
+            if ((rawBits & doubleSignBitMask) != 0) {
+              val updatedVal = rawBits ^ doubleFlipBitMask
+              valueBuffer.putDouble(longBitsToDouble(updatedVal))
+            } else {
+              valueBuffer.putDouble(doubleVal)
+            }
+            record.put(fieldIdx + 1, ByteBuffer.wrap(valueBuffer.array()))
+
+          case _ => throw new UnsupportedOperationException(
+            s"Range scan encoding not supported for data type: ${field.dataType}")
+        }
+      }
+      fieldIdx += 2
+    }
+
+    out.reset()
+    val writer = new GenericDatumWriter[GenericRecord](rangeScanAvroType)
+    val encoder = EncoderFactory.get().binaryEncoder(out, null)
+    writer.write(record, encoder)
+    encoder.flush()
+    out.toByteArray
+  }
+
+  def decodePrefixKeyForRangeScan(
Review Comment:
ditto on scaladoc please
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -299,13 +401,16 @@ class PrefixKeyScanStateEncoder(
* @param keySchema - schema of the key to be encoded
 * @param orderingOrdinals - the ordinals for which the range scan is constructed
* @param useColumnFamilies - if column family is enabled for this encoder
+ * @param avroEnc - if Avro encoding is specified for this StateEncoder, this encoder will
+ *                  be defined
*/
class RangeKeyScanStateEncoder(
keySchema: StructType,
orderingOrdinals: Seq[Int],
useColumnFamilies: Boolean = false,
- virtualColFamilyId: Option[Short] = None)
- extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None)
Review Comment:
Instead of avroEnc, I would honestly introduce another interface:
```scala
trait Serde {
def encodeToBytes(...)
def decodeToUnsafeRow(...)
def encodePrefixKeyForRangeScan(...)
def decodePrefixKeyForRangeScan(...)
}
```
and move the logic in there so that you don't have to keep doing `avroEnc.isDefined` checks for these.

The logic seems pretty similar except for the input data. The AvroStateSerde, or whatever you want to name it, would be the one holding the `private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)`.
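
To make the shape concrete, here is a small self-contained sketch of that factoring (all names are illustrative, not this PR's API): call sites pick one serde up front, so encode/decode never branch on `avroEnc.isDefined` afterwards.
```scala
trait StateSerde {
  def encodeToBytes(values: Seq[Any]): Array[Byte]
  def decodeFromBytes(bytes: Array[Byte]): Seq[Any]
}

// Stand-in for the existing UnsafeRow-based path.
object RowSerde extends StateSerde {
  def encodeToBytes(values: Seq[Any]): Array[Byte] = values.mkString(",").getBytes("UTF-8")
  def decodeFromBytes(bytes: Array[Byte]): Seq[Any] = new String(bytes, "UTF-8").split(',').toSeq
}

// Stand-in for the Avro path; the real one would own its schema state, e.g.
// `private lazy val remainingKeyAvroType = SchemaConverters.toAvroType(remainingKeySchema)`.
object AvroLikeSerde extends StateSerde {
  def encodeToBytes(values: Seq[Any]): Array[Byte] = values.mkString(";").getBytes("UTF-8")
  def decodeFromBytes(bytes: Array[Byte]): Seq[Any] = new String(bytes, "UTF-8").split(';').toSeq
}

object SerdeDemo extends App {
  // Chosen once from the configured encoding; no per-call Option checks afterwards.
  val serde: StateSerde = if (args.headOption.contains("avro")) AvroLikeSerde else RowSerde
  println(serde.decodeFromBytes(serde.encodeToBytes(Seq("key-1", 42))))
}
```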
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -563,13 +684,233 @@ class RangeKeyScanStateEncoder(
writer.getRow()
}
+ def encodePrefixKeyForRangeScan(
Review Comment:
Can you add a scaladoc please?
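
Something along these lines would cover it (wording is only a suggestion, derived from the encode logic shown above):
```scala
/**
 * Encodes the range-scan ordering columns of `row` into an Avro record and returns
 * its binary encoding. Each ordering column becomes two Avro fields: a marker byte
 * (null / negative / positive) followed by the value's big-endian bytes, with the
 * raw bits of negative floats/doubles flipped, so that the encoded bytes preserve
 * the ordering of the original values for range scans.
 */
```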
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala:
##########
@@ -492,15 +495,16 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
}
}
-  testWithColumnFamilies("rocksdb range scan multiple non-contiguous ordering columns",
+ testWithColumnFamiliesAndEncodingTypes("rocksdb range scan multiple " +
+ "non-contiguous ordering columns",
    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
val testSchema: StructType = StructType(
Seq(
- StructField("ordering-1", LongType, false),
Review Comment:
oh, why'd you have to change these? If these are not supported by Avro, do
we have any check anywhere to disallow the usage of the Avro encoder?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala:
##########
@@ -37,6 +38,17 @@ case class StateSchemaValidationResult(
schemaPath: String
)
+// Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder
+// in order to serialize from UnsafeRow to a byte array of Avro encoding.
Review Comment:
Can you please turn this into a proper scaladoc?
```scala
/**
* ...
*/
```
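
For example, keeping the existing wording:
```scala
/**
 * Avro encoder that is used by the RocksDBStateStoreProvider and RocksDBStateEncoder
 * in order to serialize from UnsafeRow to a byte array of Avro encoding.
 */
```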
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -74,10 +75,71 @@ private[sql] class RocksDBStateStoreProvider
isInternal: Boolean = false): Unit = {
    verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal)
val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
+ // Create cache key using store ID to avoid collisions
+ val avroEncCacheKey = s"${stateStoreId.operatorId}_" +
Review Comment:
Do we have the stream runId (maybe it's available in the HadoopConf)? We should
include the runId in the key as well; otherwise there could be collisions between
different runs of the same query.
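
A rough sketch of what that could look like, assuming the provider keeps the `hadoopConf` it was initialized with and that the streaming run id is published there under `StreamExecution.RUN_ID_KEY` (both assumptions, not verified against this PR):
```scala
// Sketch only: the runId source and the fallback value are assumptions.
val runId = Option(hadoopConf.get(StreamExecution.RUN_ID_KEY)).getOrElse("no-run-id")
val avroEncCacheKey = s"${runId}_${stateStoreId.operatorId}_" +
  s"${stateStoreId.partitionId}_$colFamilyName"
```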
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -74,10 +75,71 @@ private[sql] class RocksDBStateStoreProvider
isInternal: Boolean = false): Unit = {
    verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, isInternal)
val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
+ // Create cache key using store ID to avoid collisions
+ val avroEncCacheKey = s"${stateStoreId.operatorId}_" +
+ s"${stateStoreId.partitionId}_$colFamilyName"
+
+    // If we have not created the avroEncoder for this column family, create
+    // it, or look in the cache maintained in the RocksDBStateStoreProvider
+    // companion object
+    lazy val avroEnc = stateStoreEncoding match {
+      case "avro" => Some(
+        RocksDBStateStoreProvider.avroEncoderMap.computeIfAbsent(avroEncCacheKey,
+          _ => getAvroEnc(keyStateEncoderSpec, valueSchema))
+      )
+      case "unsaferow" => None
+    }
+
     keyValueEncoderMap.putIfAbsent(colFamilyName,
       (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, useColumnFamilies,
-        Some(newColFamilyId)), RocksDBStateEncoder.getValueEncoder(valueSchema,
-        useMultipleValuesPerKey)))
+        Some(newColFamilyId), avroEnc), RocksDBStateEncoder.getValueEncoder(valueSchema,
+        useMultipleValuesPerKey, avroEnc)))
+ }
+ private def getAvroSerializer(schema: StructType): AvroSerializer = {
Review Comment:
nit: line before the method please
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -128,6 +128,73 @@ trait AlsoTestWithChangelogCheckpointingEnabled
}
}
+ def testWithEncodingTypes(testName: String, testTags: Tag*)
+ (testBody: => Any): Unit = {
Review Comment:
one parameter per line like below please
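
i.e. something like (exact style per the surrounding helper methods):
```scala
  def testWithEncodingTypes(
      testName: String,
      testTags: Tag*)(testBody: => Any): Unit = {
```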
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala:
##########
@@ -58,7 +58,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
import StateStoreTestsHelper._
- testWithColumnFamilies(s"version encoding",
+ testWithColumnFamiliesAndEncodingTypes(s"version encoding",
Review Comment:
I wonder if it would have been better to just extend these classes and override the
SQLConf instead.
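
Sketch of that alternative (the suite name, the overridden member, and the conf key are all assumptions here, since the encoding conf is introduced by this PR):
```scala
// Hypothetical: force Avro encoding for every test in the parent suite instead of
// threading an encoding variant through each testWithColumnFamilies call site.
class RocksDBStateStoreAvroSuite extends RocksDBStateStoreSuite {
  override protected def sparkConf: SparkConf =
    super.sparkConf.set("spark.sql.streaming.stateStore.encodingFormat", "avro") // key name assumed
}
```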
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -661,19 +1016,27 @@ class RangeKeyScanStateEncoder(
class NoPrefixKeyStateEncoder(
keySchema: StructType,
useColumnFamilies: Boolean = false,
- virtualColFamilyId: Option[Short] = None)
- extends RocksDBKeyStateEncoderBase(useColumnFamilies, virtualColFamilyId) {
+ virtualColFamilyId: Option[Short] = None,
+ avroEnc: Option[AvroEncoder] = None)
Review Comment:
ditto on the Serde.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -593,6 +657,9 @@ object RocksDBStateStoreProvider {
val STATE_ENCODING_VERSION: Byte = 0
val VIRTUAL_COL_FAMILY_PREFIX_BYTES = 2
+  // Add the cache at companion object level so it persists across provider instances
+  private val avroEncoderMap = new java.util.concurrent.ConcurrentHashMap[String, AvroEncoder]()
Review Comment:
Do we want to leverage LinkedHashMap to limit the size of the cache?
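
A bounded variant could look roughly like this (a sketch; the cap is arbitrary, and `Collections.synchronizedMap` gives up `ConcurrentHashMap`'s lock-free reads in exchange for LRU eviction):
```scala
import java.util.{Collections, LinkedHashMap => JLinkedHashMap, Map => JMap}

// Sketch: access-ordered LinkedHashMap evicting the least recently used encoder
// once the (arbitrary) cap is exceeded.
private val MAX_AVRO_ENCODER_CACHE_SIZE = 1000

private val avroEncoderMap: JMap[String, AvroEncoder] = Collections.synchronizedMap(
  new JLinkedHashMap[String, AvroEncoder](16, 0.75f, /* accessOrder = */ true) {
    override def removeEldestEntry(eldest: JMap.Entry[String, AvroEncoder]): Boolean =
      size() > MAX_AVRO_ENCODER_CACHE_SIZE
  })
```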