HeartSaVioR commented on code in PR #45503:
URL: https://github.com/apache/spark/pull/45503#discussion_r1535006405


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -77,9 +77,18 @@ class StatePartitionReader(
       stateStoreMetadata.head.numColsPrefixKey
     }
 
+    // TODO: currently we don't support RangeKeyScanStateEncoderSpec. Support 
for this will be

Review Comment:
   It would be good to file a JIRA ticket for this and leave the ticket number
in the comment. Could you please file one?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -110,9 +120,6 @@ class PrefixKeyScanStateEncoder(
 
   import RocksDBStateEncoder._
 
-  require(keySchema.length > numColsPrefixKey, "The number of columns in the 
key must be " +

Review Comment:
   The reason I added this here is that it makes no sense semantically if
keySchema.length == numColsPrefixKey. In that case the caller should just use
get, and a prefix scan isn't needed.
   
   What is the desired use case?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -192,6 +199,241 @@ class PrefixKeyScanStateEncoder(
   override def supportPrefixKeyScan: Boolean = true
 }
 
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size 
fields
+ *
+ * To encode a row for range scan, we first project the first numOrderingCols 
needed
+ * for the range scan into an UnsafeRow; we then rewrite that UnsafeRow's 
fields in BIG_ENDIAN
+ * to allow for scanning keys in sorted order using the byte-wise comparison 
method that
+ * RocksDB uses.
+ * Then, for the rest of the fields, we project those into another UnsafeRow.
+ * We then effectively join these two UnsafeRows together, and finally take 
those bytes
+ * to get the resulting row.
+ * We cannot support variable sized fields given the UnsafeRow format which 
stores variable
+ * sized fields as offset and length pointers to the actual values, thereby 
changing the required
+ * ordering.
+ *
+ * @param keySchema - schema of the key to be encoded
+ * @param numOrderingCols - number of columns to be used for range scan
+ */
+class RangeKeyScanStateEncoder(
+    keySchema: StructType,
+    numOrderingCols: Int) extends RocksDBKeyStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  private val rangeScanKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numOrderingCols)
+  }
+
+  private def isFixedSize(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: 
LongType |
+      _: FloatType | _: DoubleType => true
+    case _ => false
+  }
+
+  // verify that only fixed sized columns are used for ordering
+  rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+    if (!isFixedSize(field.dataType)) {
+      // NullType is technically fixed size, but not supported for ordering
+      if (field.dataType == NullType) {
+        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, 
idx.toString)
+      } else {
+        throw 
StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, idx.toString)
+      }
+    }
+  }
+
+  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.drop(numOrderingCols)
+  }
+
+  private val rangeScanKeyProjection: UnsafeProjection = {
+    val refs = rangeScanKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val remainingKeyProjection: UnsafeProjection = {
+    val refs = remainingKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val restoreKeyProjection: UnsafeProjection = 
UnsafeProjection.create(keySchema)
+
+  // Reusable objects
+  private val joinedRowOnKey = new JoinedRow()
+
+  private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+    rangeScanKeyProjection(key)
+  }
+
+  // Rewrite the unsafe row by replacing fixed size fields with BIG_ENDIAN 
encoding
+  // using byte arrays.
+  private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
+    val writer = new UnsafeRowWriter(numOrderingCols)
+    writer.resetRowWriter()
+    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+      val value = row.get(idx, field.dataType)
+      field.dataType match {
+        // endian-ness doesn't matter for single byte objects. so just write 
these
+        // types directly.
+        case BooleanType => writer.write(idx, value.asInstanceOf[Boolean])
+        case ByteType => writer.write(idx, value.asInstanceOf[Byte])
+
+        // for other multi-byte types, we need to convert to big-endian
+        case ShortType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putShort(value.asInstanceOf[Short])
+          writer.write(idx, bbuf.array())
+
+        case IntegerType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putInt(value.asInstanceOf[Int])
+          writer.write(idx, bbuf.array())
+
+        case LongType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putLong(value.asInstanceOf[Long])
+          writer.write(idx, bbuf.array())
+
+        case FloatType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putFloat(value.asInstanceOf[Float])
+          writer.write(idx, bbuf.array())
+
+        case DoubleType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putDouble(value.asInstanceOf[Double])
+          writer.write(idx, bbuf.array())
+      }
+    }
+    writer.getRow().copy()

Review Comment:
   nit: is copy() needed?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -289,6 +289,26 @@ class InvalidUnsafeRowException(error: String)
     "among restart. For the first case, you can try to restart the application 
without " +
     s"checkpoint or use the legacy Spark version to process the streaming 
state.\n$error", null)
 
+sealed trait KeyStateEncoderSpec
+
+case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends 
KeyStateEncoderSpec
+
+case class PrefixKeyScanStateEncoderSpec(
+    keySchema: StructType,
+    numColsPrefixKey: Int) extends KeyStateEncoderSpec {
+  if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) {

Review Comment:
   Ah OK, you check the equality here, good.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala:
##########
@@ -192,6 +199,241 @@ class PrefixKeyScanStateEncoder(
   override def supportPrefixKeyScan: Boolean = true
 }
 
+/**
+ * RocksDB Key Encoder for UnsafeRow that supports range scan for fixed size 
fields
+ *
+ * To encode a row for range scan, we first project the first numOrderingCols 
needed
+ * for the range scan into an UnsafeRow; we then rewrite that UnsafeRow's 
fields in BIG_ENDIAN
+ * to allow for scanning keys in sorted order using the byte-wise comparison 
method that
+ * RocksDB uses.
+ * Then, for the rest of the fields, we project those into another UnsafeRow.
+ * We then effectively join these two UnsafeRows together, and finally take 
those bytes
+ * to get the resulting row.
+ * We cannot support variable sized fields given the UnsafeRow format which 
stores variable
+ * sized fields as offset and length pointers to the actual values, thereby 
changing the required
+ * ordering.
+ *
+ * @param keySchema - schema of the key to be encoded
+ * @param numOrderingCols - number of columns to be used for range scan
+ */
+class RangeKeyScanStateEncoder(
+    keySchema: StructType,
+    numOrderingCols: Int) extends RocksDBKeyStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  private val rangeScanKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numOrderingCols)
+  }
+
+  private def isFixedSize(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: BooleanType | _: ShortType | _: IntegerType | _: 
LongType |
+      _: FloatType | _: DoubleType => true
+    case _ => false
+  }
+
+  // verify that only fixed sized columns are used for ordering
+  rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+    if (!isFixedSize(field.dataType)) {
+      // NullType is technically fixed size, but not supported for ordering
+      if (field.dataType == NullType) {
+        throw StateStoreErrors.nullTypeOrderingColsNotSupported(field.name, 
idx.toString)
+      } else {
+        throw 
StateStoreErrors.variableSizeOrderingColsNotSupported(field.name, idx.toString)
+      }
+    }
+  }
+
+  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.drop(numOrderingCols)
+  }
+
+  private val rangeScanKeyProjection: UnsafeProjection = {
+    val refs = rangeScanKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val remainingKeyProjection: UnsafeProjection = {
+    val refs = remainingKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val restoreKeyProjection: UnsafeProjection = 
UnsafeProjection.create(keySchema)
+
+  // Reusable objects
+  private val joinedRowOnKey = new JoinedRow()
+
+  private def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+    rangeScanKeyProjection(key)
+  }
+
+  // Rewrite the unsafe row by replacing fixed size fields with BIG_ENDIAN 
encoding
+  // using byte arrays.
+  private def encodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
+    val writer = new UnsafeRowWriter(numOrderingCols)
+    writer.resetRowWriter()
+    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+      val value = row.get(idx, field.dataType)
+      field.dataType match {
+        // endian-ness doesn't matter for single byte objects. so just write 
these
+        // types directly.
+        case BooleanType => writer.write(idx, value.asInstanceOf[Boolean])
+        case ByteType => writer.write(idx, value.asInstanceOf[Byte])
+
+        // for other multi-byte types, we need to convert to big-endian
+        case ShortType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putShort(value.asInstanceOf[Short])
+          writer.write(idx, bbuf.array())
+
+        case IntegerType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putInt(value.asInstanceOf[Int])
+          writer.write(idx, bbuf.array())
+
+        case LongType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putLong(value.asInstanceOf[Long])
+          writer.write(idx, bbuf.array())
+
+        case FloatType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putFloat(value.asInstanceOf[Float])
+          writer.write(idx, bbuf.array())
+
+        case DoubleType =>
+          val bbuf = ByteBuffer.allocate(field.dataType.defaultSize)
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          bbuf.putDouble(value.asInstanceOf[Double])
+          writer.write(idx, bbuf.array())
+      }
+    }
+    writer.getRow().copy()
+  }
+
+  // Rewrite the unsafe row by converting back from BIG_ENDIAN byte arrays to 
the
+  // original data types.
+  private def decodePrefixKeyForRangeScan(row: UnsafeRow): UnsafeRow = {
+    val writer = new UnsafeRowWriter(numOrderingCols)
+    writer.resetRowWriter()
+    rangeScanKeyFieldsWithIdx.foreach { case (field, idx) =>
+      val value = if (field.dataType == BooleanType || field.dataType == 
ByteType) {
+        row.get(idx, field.dataType)
+      } else {
+        row.getBinary(idx)
+      }
+
+      field.dataType match {
+        // for single byte types, read them directly
+        case BooleanType => writer.write(idx, value.asInstanceOf[Boolean])
+        case ByteType => writer.write(idx, value.asInstanceOf[Byte])
+
+        // for multi-byte types, convert from big-endian
+        case ShortType =>
+          val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          writer.write(idx, bbuf.getShort)
+
+        case IntegerType =>
+          val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          writer.write(idx, bbuf.getInt)
+
+        case LongType =>
+          val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          writer.write(idx, bbuf.getLong)
+
+        case FloatType =>
+          val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          writer.write(idx, bbuf.getFloat)
+
+        case DoubleType =>
+          val bbuf = ByteBuffer.wrap(value.asInstanceOf[Array[Byte]])
+          bbuf.order(ByteOrder.BIG_ENDIAN)
+          writer.write(idx, bbuf.getDouble)
+      }
+    }
+    writer.getRow().copy()

Review Comment:
   nit: ditto - is copy() needed here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to