jackylee-ch commented on code in PR #12211:
URL: https://github.com/apache/gluten/pull/12211#discussion_r3360023664
##########
backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala:
##########
@@ -617,45 +663,109 @@ object CachedColumnarBatchKryoSerializer {
}
/**
- * Parse the JNI `serializeWithStats` framed return into (stats InternalRow, bytesBlob).
- *
- * Framed layout (matches cpp VeloxColumnarBatchSerializer.cc): `[ STATS_FRAMED_MAGIC: 4B ] [
- * statsLen: u32 LE ] [ statsBlob ] [ bytesLen: u32 LE ] [ bytesBlob ]`.
+ * Parse the JNI `serializeWithStats` framed return into (stats InternalRow, bytesBlob). Routes on
+ * the full 4-byte magic: V2 -> 0xFECA5302, V3 -> 0xFECA5303.
*
- * Eager guards catch corrupt magic / truncated framing before they propagate.
+ * V2 layout: `[ magic: 4B ] [ statsLen: u32 LE ] [ statsBlob ] [ bytesLen: u32 LE ] [ bytesBlob
+ * ]` V3 layout: `[ magic: 4B ] [ statsLen: u32 LE ] [ statsBlob ] [ numRows: u32 LE ] [ numCols:
+ * u32 LE ] [ per-col ]`
*/
private[execution] def parseFramedBytes(
framed: Array[Byte],
schema: StructType): (InternalRow, Array[Byte]) = {
+ // V2 minimum = 4+4+4=12B; V3 minimum = 4+4+4+4=16B; use 12 for dispatcher guard.
require(
- framed != null && framed.length >= 4 + 4 + 4,
+ framed != null && framed.length >= 12,
s"framed bytes too short: len=${if (framed == null) -1 else
framed.length}")
- require(
- framed(0) == STATS_FRAMED_MAGIC(0) && framed(1) == STATS_FRAMED_MAGIC(1)
&&
- framed(2) == STATS_FRAMED_MAGIC(2) && framed(3) ==
STATS_FRAMED_MAGIC(3),
- f"framed bytes magic mismatch: expected " +
- f"0x${STATS_FRAMED_MAGIC(0) & 0xff}%02X${STATS_FRAMED_MAGIC(1) &
0xff}%02X" +
- f"${STATS_FRAMED_MAGIC(2) & 0xff}%02X${STATS_FRAMED_MAGIC(3) &
0xff}%02X, got " +
- f"0x${framed(0) & 0xff}%02X${framed(1) & 0xff}%02X" +
- f"${framed(2) & 0xff}%02X${framed(3) & 0xff}%02X"
- )
+ framedMagicVersion(framed) match {
+ case 0x02 => parseV2Frame(framed, schema)
+ case 0x03 => parseV3Frame(framed, schema)
+ }
+ }
+
+ /** V2 parse: extract stats + pure Presto bytesBlob. */
+ private def parseV2Frame(framed: Array[Byte], schema: StructType):
(InternalRow, Array[Byte]) = {
+ requireFrameMagic(framed, STATS_FRAMED_MAGIC, "V2")
val buf = ByteBuffer.wrap(framed).order(ByteOrder.LITTLE_ENDIAN)
buf.position(4) // skip magic
val statsLen = buf.getInt
require(
statsLen >= 0 && statsLen <= buf.remaining() - 4,
- s"framed bytes statsLen=$statsLen exceeds remaining buffer
${buf.remaining() - 4}")
+ s"V2 framed bytes statsLen=$statsLen exceeds remaining buffer
${buf.remaining() - 4}")
val statsBlob = new Array[Byte](statsLen)
buf.get(statsBlob)
val stats = deserializeStats(statsBlob, schema)
val bytesLen = buf.getInt
require(
bytesLen >= 0 && bytesLen == buf.remaining(),
- s"framed bytes bytesLen=$bytesLen != remaining ${buf.remaining()}
(truncated or trailing)")
+ s"V2 framed bytes bytesLen=$bytesLen != remaining ${buf.remaining()}
(truncated or trailing)")
val bytesBlob = new Array[Byte](bytesLen)
buf.get(bytesBlob)
(stats, bytesBlob)
}
+
+ /**
+ * V3 parse: extract stats; bytes = the full V3 framed array (C++ deserializeV3 starts at magic).
+ * Invariant: returned bytes[0..3] == V3 magic; C++ deserializeV3 re-validates the schema-level
+ * contract, while the JVM parser fails fast on top-level frame bounds.
+ */
+ private def parseV3Frame(framed: Array[Byte], schema: StructType):
(InternalRow, Array[Byte]) = {
+ val parsed = parseV3FrameInternal(framed, schema, decodeStats = true)
+ (parsed.stats, parsed.bytes)
+ }
+
+ private[execution] def requireV3FrameNumRows(
+ framed: Array[Byte],
+ expectedNumRows: Int,
+ context: String): Unit = {
+ val frameNumRows = parseV3FrameInternal(framed, null, decodeStats =
false).numRows
+ require(
+ frameNumRows == expectedNumRows,
+ s"$context: V3 frame numRows=$frameNumRows != CachedBatch
numRows=$expectedNumRows")
+ }
+
+ private def parseV3FrameInternal(
+ framed: Array[Byte],
+ schema: StructType,
+ decodeStats: Boolean): V3ParsedFrame = {
+ require(framed.length >= 16, s"V3 framed bytes too short (min 16B):
len=${framed.length}")
+ requireFrameMagic(framed, STATS_FRAMED_MAGIC_V3, "V3")
+ val buf = ByteBuffer.wrap(framed).order(ByteOrder.LITTLE_ENDIAN)
+ buf.position(4) // skip magic
+ val statsLen = buf.getInt
+ require(
+ statsLen >= 0 && statsLen <= buf.remaining() - 8, // 8 =
numRows(4)+numCols(4)
+ s"V3 framed bytes statsLen=$statsLen invalid")
+ val statsBlob = new Array[Byte](statsLen)
+ buf.get(statsBlob)
+ val stats =
+ if (!decodeStats || statsLen == 0) null else deserializeStats(statsBlob,
schema)
+ val numRows = buf.getInt
+ require(numRows >= 0, s"V3 framed bytes numRows=$numRows invalid")
+ val numCols = buf.getInt
+ require(numCols >= 0, s"V3 framed bytes numCols=$numCols invalid")
Review Comment:
done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]