This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new beb71bb5a443 [SPARK-51813][SQL][CORE] Add a nonnullable 
DefaultCachedBatchKryoSerializer to avoid null propagating in 
DefaultCachedBatch serde
beb71bb5a443 is described below

commit beb71bb5a4438fd0204601425c907eb7d5ec8874
Author: Kent Yao <y...@apache.org>
AuthorDate: Wed Apr 16 13:55:55 2025 +0800

    [SPARK-51813][SQL][CORE] Add a nonnullable DefaultCachedBatchKryoSerializer 
to avoid null propagating in DefaultCachedBatch serde
    
    ### What changes were proposed in this pull request?
    
    Add a non-nullable DefaultCachedBatchKryoSerializer to prevent null propagation in DefaultCachedBatch serialization/deserialization (serde).
    
    ### Why are the changes needed?
    
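    The cached data might occasionally become malformed. If Kryo fails to serialize/deserialize it and throws an exception, that is acceptable, because the failure surfaces immediately. But if Kryo silently reads the data as null, Spark later fails with NullPointerExceptions that are far removed from the actual cause.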
    
    
![image](https://github.com/user-attachments/assets/7e9d4286-32eb-4ad8-b760-a80d9c0b53ca)
    
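    We need to detect these errors early, at serde time, by failing with a descriptive INVALID_KRYO_SERIALIZER_NO_DATA error instead of producing null.
    
    ### Does this PR introduce _any_ user-facing change?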
    No
    
    ### How was this patch tested?
    New unit tests in CacheTableInKryoSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #50599 from yaooqinn/SPARK-51813.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |  6 ++
 .../sql/execution/columnar/InMemoryRelation.scala  | 55 ++++++++++++++++-
 .../apache/spark/sql/CacheTableInKryoSuite.scala   | 68 +++++++++++++++++++++-
 3 files changed, 125 insertions(+), 4 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 70550f0b4e13..3d7977673a3f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2937,6 +2937,12 @@
     ],
     "sqlState" : "F0000"
   },
+  "INVALID_KRYO_SERIALIZER_NO_DATA" : {
+    "message" : [
+      "The object '<obj>' is invalid or malformed to <serdeOp> using 
<serdeClass>."
+    ],
+    "sqlState" : "22002"
+  },
   "INVALID_LABEL_USAGE" : {
     "message" : [
       "The usage of the label <labelName> is invalid."
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 0f280d236203..bdbaee16d4e9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.columnar
 
+import com.esotericsoftware.kryo.{DefaultSerializer, Kryo, Serializer => 
KryoSerializer}
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.{SparkException, TaskContext}
@@ -30,11 +32,11 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Sta
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer, 
SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer}
-import org.apache.spark.sql.execution.{ColumnarToRowTransition, InputAdapter, 
QueryExecution, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, 
OnHeapColumnVector, WritableColumnVector}
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, 
FloatType, IntegerType, LongType, ShortType, StructType, UserDefinedType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -47,9 +49,56 @@ import org.apache.spark.util.ArrayImplicits._
  * @param buffers The buffers for serialized columns
  * @param stats The stat of columns
  */
-case class DefaultCachedBatch(numRows: Int, buffers: Array[Array[Byte]], 
stats: InternalRow)
+@DefaultSerializer(classOf[DefaultCachedBatchKryoSerializer])
+case class DefaultCachedBatch(
+     numRows: Int,
+     buffers: Array[Array[Byte]],
+     stats: InternalRow)
   extends SimpleMetricsCachedBatch
 
+class DefaultCachedBatchKryoSerializer extends 
KryoSerializer[DefaultCachedBatch] {
+  override def write(kryo: Kryo, output: KryoOutput, batch: 
DefaultCachedBatch): Unit = {
+    output.writeInt(batch.numRows)
+    SparkException.require(batch.buffers != null, 
"INVALID_KRYO_SERIALIZER_NO_DATA",
+      Map("obj" -> "DefaultCachedBatch.buffers",
+        "serdeOp" -> "serialize",
+        "serdeClass" -> this.getClass.getName))
+    output.writeInt(batch.buffers.length + 1) // +1 to distinguish Kryo.NULL
+    for (i <- batch.buffers.indices) {
+      val buffer = batch.buffers(i)
+        SparkException.require(buffer != null, 
"INVALID_KRYO_SERIALIZER_NO_DATA",
+          Map("obj" -> s"DefaultCachedBatch.buffers($i)",
+            "serdeOp" -> "serialize",
+            "serdeClass" -> this.getClass.getName))
+      output.writeInt(buffer.length + 1)  // +1 to distinguish Kryo.NULL
+      output.writeBytes(buffer)
+    }
+    kryo.writeClassAndObject(output, batch.stats)
+  }
+
+  override def read(
+      kryo: Kryo, input: KryoInput, cls: Class[DefaultCachedBatch]): 
DefaultCachedBatch = {
+    val numRows = input.readInt()
+    val length = input.readInt()
+    SparkException.require(length != Kryo.NULL, 
"INVALID_KRYO_SERIALIZER_NO_DATA",
+      Map("obj" -> "DefaultCachedBatch.buffers",
+        "serdeOp" -> "deserialize",
+        "serdeClass" -> this.getClass.getName))
+    val buffers = 0.until(length - 1).map { i => // -1 to restore
+      val subLength = input.readInt()
+      SparkException.require(subLength != Kryo.NULL, 
"INVALID_KRYO_SERIALIZER_NO_DATA",
+          Map("obj" -> s"DefaultCachedBatch.buffers($i)",
+          "serdeOp" -> "deserialize",
+          "serdeClass" -> this.getClass.getName))
+      val innerArray = new Array[Byte](subLength - 1) // -1 to restore
+      input.readBytes(innerArray)
+      innerArray
+    }.toArray
+    val stats = kryo.readClassAndObject(input).asInstanceOf[InternalRow]
+    DefaultCachedBatch(numRows, buffers, stats)
+  }
+}
+
 /**
  * The default implementation of CachedBatchSerializer.
  */
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CacheTableInKryoSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CacheTableInKryoSuite.scala
index 1b2fbb5d4aa8..26d8f750f6e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CacheTableInKryoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CacheTableInKryoSuite.scala
@@ -17,7 +17,13 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.SparkConf
+import com.esotericsoftware.kryo.Kryo
+import com.esotericsoftware.kryo.io.Input
+
+import org.apache.spark.{SparkConf, SparkIllegalArgumentException}
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.columnar.{DefaultCachedBatch, 
DefaultCachedBatchKryoSerializer}
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.storage.StorageLevel
 
@@ -52,4 +58,64 @@ class CacheTableInKryoSuite extends QueryTest
             Seq(Row("apache", "spark", "community"), Row("Apache", "Spark", 
"Community")))
     }
   }
+
+  test("SPARK-51813 DefaultCachedBatchKryoSerializer do not propagate nulls") {
+    val ks = new KryoSerializer(this.sparkConf)
+    val kryo = ks.newKryo()
+    val serializer = kryo.getDefaultSerializer(classOf[DefaultCachedBatch])
+    assert(serializer.isInstanceOf[DefaultCachedBatchKryoSerializer])
+    val ser = serializer.asInstanceOf[DefaultCachedBatchKryoSerializer]
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        ser.write(kryo, ks.newKryoOutput(), DefaultCachedBatch(1, null, 
InternalRow.empty))
+      },
+      condition = "INVALID_KRYO_SERIALIZER_NO_DATA",
+      parameters = Map(
+        "obj" -> "DefaultCachedBatch.buffers",
+        "serdeOp" -> "serialize",
+        "serdeClass" -> ser.getClass.getName))
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        ser.write(kryo, ks.newKryoOutput(),
+          DefaultCachedBatch(1, Seq(Array.empty[Byte], null).toArray, 
InternalRow.empty))
+      },
+      condition = "INVALID_KRYO_SERIALIZER_NO_DATA",
+      parameters = Map(
+        "obj" -> "DefaultCachedBatch.buffers(1)",
+        "serdeOp" -> "serialize",
+        "serdeClass" -> ser.getClass.getName))
+
+    val output1 = ks.newKryoOutput()
+    output1.writeInt(1) // numRows
+    output1.writeInt(Kryo.NULL) // malformed buffers.length
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        ser.read(kryo, new Input(output1.toBytes), classOf[DefaultCachedBatch])
+      },
+      condition = "INVALID_KRYO_SERIALIZER_NO_DATA",
+      parameters = Map(
+        "obj" -> "DefaultCachedBatch.buffers",
+        "serdeOp" -> "deserialize",
+        "serdeClass" -> ser.getClass.getName))
+    output1.close()
+
+    val output2 = ks.newKryoOutput()
+    output2.writeInt(1) // numRows
+    output2.writeInt(3) // buffers.length + 1
+    output2.writeInt(Kryo.NULL) // malformed buffers[0].length
+    output2.writeBytes(Array[Byte](1, 2, 3)) // buffers[0]
+
+    checkError(
+      exception = intercept[SparkIllegalArgumentException] {
+        ser.read(kryo, new Input(output2.toBytes, 0, 14), 
classOf[DefaultCachedBatch])
+      },
+      condition = "INVALID_KRYO_SERIALIZER_NO_DATA",
+      parameters = Map(
+        "obj" -> "DefaultCachedBatch.buffers(0)",
+        "serdeOp" -> "deserialize",
+        "serdeClass" -> ser.getClass.getName))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to