This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new beb71bb5a443 [SPARK-51813][SQL][CORE] Add a nonnullable DefaultCachedBatchKryoSerializer to avoid null propagating in DefaultCachedBatch serde
beb71bb5a443 is described below

commit beb71bb5a4438fd0204601425c907eb7d5ec8874
Author: Kent Yao <y...@apache.org>
AuthorDate: Wed Apr 16 13:55:55 2025 +0800

    [SPARK-51813][SQL][CORE] Add a nonnullable DefaultCachedBatchKryoSerializer to avoid null propagating in DefaultCachedBatch serde

    ### What changes were proposed in this pull request?

    Add a nonnullable DefaultCachedBatchKryoSerializer to avoid propagating nulls in the DefaultCachedBatch serde.

    ### Why are the changes needed?

    Cached data can occasionally become malformed. If Kryo fails to serialize or deserialize it and throws an exception, that is acceptable; but if Kryo silently reads the malformed data back as null, Spark later fails with NullPointerExceptions. We need to detect such errors early.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    New tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #50599 from yaooqinn/SPARK-51813.

    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../src/main/resources/error/error-conditions.json |  6 ++
 .../sql/execution/columnar/InMemoryRelation.scala  | 55 ++++++++++++++++-
 .../apache/spark/sql/CacheTableInKryoSuite.scala   | 68 +++++++++++++++++++++-
 3 files changed, 125 insertions(+), 4 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 70550f0b4e13..3d7977673a3f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2937,6 +2937,12 @@
     ],
     "sqlState" : "F0000"
   },
+  "INVALID_KRYO_SERIALIZER_NO_DATA" : {
+    "message" : [
+      "The object '<obj>' is invalid or malformed to <serdeOp> using <serdeClass>."
+    ],
+    "sqlState" : "22002"
+  },
   "INVALID_LABEL_USAGE" : {
     "message" : [
       "The usage of the label <labelName> is invalid."
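Note (not part of the patch): the <obj>, <serdeOp> and <serdeClass> placeholders in the new condition are filled from the message-parameter map passed to SparkException.require, which is how the serializer below raises it; a failed requirement surfaces as a SparkIllegalArgumentException, as the tests further down expect. A minimal sketch of the pattern — the helper name checkBuffers is illustrative only:

    import org.apache.spark.SparkException

    // Hypothetical helper: raise INVALID_KRYO_SERIALIZER_NO_DATA when the buffers
    // array is missing. The map keys must match the placeholders in the message.
    def checkBuffers(buffers: Array[Array[Byte]], serdeClass: String): Unit = {
      SparkException.require(buffers != null, "INVALID_KRYO_SERIALIZER_NO_DATA",
        Map(
          "obj" -> "DefaultCachedBatch.buffers",   // fills <obj>
          "serdeOp" -> "serialize",                // fills <serdeOp>
          "serdeClass" -> serdeClass))             // fills <serdeClass>
    }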
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 0f280d236203..bdbaee16d4e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.columnar
 
+import com.esotericsoftware.kryo.{DefaultSerializer, Kryo, Serializer => KryoSerializer}
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.{SparkException, TaskContext}
@@ -30,11 +32,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Sta
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer, SimpleMetricsCachedBatch, SimpleMetricsCachedBatchSerializer}
-import org.apache.spark.sql.execution.{ColumnarToRowTransition, InputAdapter, QueryExecution, SparkPlan, WholeStageCodegenExec}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector}
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.types.{BooleanType, ByteType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructType, UserDefinedType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.{LongAccumulator, Utils}
@@ -47,9 +49,56 @@ import org.apache.spark.util.ArrayImplicits._
  * @param buffers The buffers for serialized columns
  * @param stats The stat of columns
  */
-case class DefaultCachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
+@DefaultSerializer(classOf[DefaultCachedBatchKryoSerializer])
+case class DefaultCachedBatch(
+    numRows: Int,
+    buffers: Array[Array[Byte]],
+    stats: InternalRow)
   extends SimpleMetricsCachedBatch
 
+class DefaultCachedBatchKryoSerializer extends KryoSerializer[DefaultCachedBatch] {
+  override def write(kryo: Kryo, output: KryoOutput, batch: DefaultCachedBatch): Unit = {
+    output.writeInt(batch.numRows)
+    SparkException.require(batch.buffers != null, "INVALID_KRYO_SERIALIZER_NO_DATA",
+      Map("obj" -> "DefaultCachedBatch.buffers",
+        "serdeOp" -> "serialize",
+        "serdeClass" -> this.getClass.getName))
+    output.writeInt(batch.buffers.length + 1) // +1 to distinguish Kryo.NULL
+    for (i <- batch.buffers.indices) {
+      val buffer = batch.buffers(i)
+      SparkException.require(buffer != null, "INVALID_KRYO_SERIALIZER_NO_DATA",
+        Map("obj" -> s"DefaultCachedBatch.buffers($i)",
+          "serdeOp" -> "serialize",
+          "serdeClass" -> this.getClass.getName))
+      output.writeInt(buffer.length + 1) // +1 to distinguish Kryo.NULL
+      output.writeBytes(buffer)
+    }
+    kryo.writeClassAndObject(output, batch.stats)
+  }
+
+  override def read(
+      kryo: Kryo, input: KryoInput, cls: Class[DefaultCachedBatch]): DefaultCachedBatch = {
+    val numRows = input.readInt()
+    val length = input.readInt()
+    SparkException.require(length != Kryo.NULL, "INVALID_KRYO_SERIALIZER_NO_DATA",
+      Map("obj" -> "DefaultCachedBatch.buffers",
+        "serdeOp" -> "deserialize",
+        "serdeClass" -> this.getClass.getName))
+    val buffers = 0.until(length - 1).map { i => // -1 to restore
+      val subLength = input.readInt()
+      SparkException.require(subLength != Kryo.NULL, "INVALID_KRYO_SERIALIZER_NO_DATA",
+        Map("obj" -> s"DefaultCachedBatch.buffers($i)",
+          "serdeOp" -> "deserialize",
+          "serdeClass" -> this.getClass.getName))
+      val innerArray = new Array[Byte](subLength - 1) // -1 to restore
+      input.readBytes(innerArray)
+      innerArray
+    }.toArray
+    val stats = kryo.readClassAndObject(input).asInstanceOf[InternalRow]
+    DefaultCachedBatch(numRows, buffers, stats)
+  }
+}
+
 /**
  * The default implementation of CachedBatchSerializer.
  */
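Note (not part of the patch): Kryo uses 0 (Kryo.NULL) as its null marker, so the serializer above stores every length as length + 1. A legitimately empty buffer is written as 1 and can never be confused with the null marker, while reading 0 back can only mean the stream is malformed and is rejected instead of silently materializing nulls. A standalone sketch of the same encoding, assuming only the Kryo library on the classpath (the object name LengthOffsetSketch is illustrative):

    import com.esotericsoftware.kryo.Kryo
    import com.esotericsoftware.kryo.io.{Input, Output}

    object LengthOffsetSketch {
      def main(args: Array[String]): Unit = {
        val payload = Array[Byte](1, 2, 3)
        val out = new Output(1024, -1)       // growable buffer, no size limit
        out.writeInt(payload.length + 1)     // +1 so a valid length is never Kryo.NULL (0)
        out.writeBytes(payload)

        val in = new Input(out.toBytes)
        val stored = in.readInt()
        require(stored != Kryo.NULL, "malformed: length field is Kryo.NULL")
        val restored = new Array[Byte](stored - 1)  // -1 to recover the real length
        in.readBytes(restored)
        assert(restored.sameElements(payload))
      }
    }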
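Note (not part of the suite): the new test only exercises malformed inputs. For context, a rough sketch of the happy path under the same setup — a well-formed DefaultCachedBatch should round-trip through the annotated serializer. It assumes the sql/core test classpath; the object name RoundTripSketch and the sample values are illustrative:

    import com.esotericsoftware.kryo.io.{Input, Output}
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.columnar.DefaultCachedBatch

    object RoundTripSketch {
      def main(args: Array[String]): Unit = {
        // Spark's KryoSerializer yields a Kryo instance that allows unregistered
        // classes and honors the @DefaultSerializer annotation on DefaultCachedBatch.
        val ks = new KryoSerializer(new SparkConf())
        val kryo = ks.newKryo()

        val batch = DefaultCachedBatch(
          2, Array(Array[Byte](1, 2), Array.empty[Byte]), InternalRow.empty)

        val output = new Output(4096, -1)
        kryo.writeObject(output, batch)   // dispatches to DefaultCachedBatchKryoSerializer
        val restored =
          kryo.readObject(new Input(output.toBytes), classOf[DefaultCachedBatch])

        assert(restored.numRows == 2)
        assert(restored.buffers.length == 2 && restored.buffers(1).isEmpty)
      }
    }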

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org