This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 19ca63f82fed [SPARK-53553][CONNECT] Fix handling of null values in LiteralValueProtoConverter
19ca63f82fed is described below

commit 19ca63f82fedfa78a25205b12fe3aacf8c2fc815
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Mon Sep 15 09:46:40 2025 +0800

[SPARK-53553][CONNECT] Fix handling of null values in LiteralValueProtoConverter

### What changes were proposed in this pull request?

This PR fixes the handling of null literal values in `LiteralValueProtoConverter` for Spark Connect. The main changes are:

1. **Added proper null value handling**: Created a new `setNullValue` method that correctly sets null values in proto literals, with the appropriate data type information.
2. **Reordered pattern matching**: Moved null and Option handling to the top of the pattern match in `toLiteralProtoBuilderInternal`, so that null values are processed before any type-specific logic.
3. **Fixed converter logic**: Updated the `getScalaConverter` method to handle null values properly by checking `hasNull` before applying the type-specific conversion logic (see the sketch after this list).
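The third change follows a simple pattern: build the type-specific converter first, then wrap it in a single shared null guard so that no per-type case has to re-check for null. A minimal standalone sketch of that idea, using a hypothetical `Lit` stand-in for `proto.Expression.Literal` rather than the actual proto classes:

```scala
// Toy stand-in for proto.Expression.Literal: a value that may be tagged as null.
case class Lit(hasNull: Boolean, value: Any)

def getConverter(kind: String): Lit => Any = {
  // Type-specific conversion, written without any null handling.
  val converter: Lit => Any = kind match {
    case "integer" => l => l.value.asInstanceOf[Int]
    case "string" => l => l.value.asInstanceOf[String]
    case other => throw new IllegalArgumentException(s"Unsupported type: $other")
  }
  // Single shared guard: short-circuit to null before the typed logic runs.
  l => if (l.hasNull) null else converter(l)
}

// getConverter("integer")(Lit(hasNull = true, value = null))  // returns null; no cast is attempted
```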
### Why are the changes needed?

The previous implementation had several issues with null value handling:

1. **Incorrect null processing order**: Null values were processed only after type-specific logic, which could lead to exceptions.
2. **Missing null checks in converters**: The converter functions did not check for null values before applying type-specific conversion logic.

### Does this PR introduce _any_ user-facing change?

**Yes**. This PR fixes a bug where null values in literals (especially in arrays and maps) were not handled properly in Spark Connect. Users who previously hit incorrect serialization of null values in complex types should now see correct behavior.
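As a concrete illustration (mirroring the new test in `ClientE2ETestSuite` below), a typed literal array containing nulls now round-trips correctly; this snippet assumes a Spark Connect `spark` session is in scope:

```scala
import org.apache.spark.sql.functions.typedlit

val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col"))
val arr = df.collect()(0).getAs[Array[Integer]]("arr_col")
assert(arr.sameElements(Array[Integer](1, null)))  // previously the null element was mishandled
```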
### How was this patch tested?

`build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z SPARK-53553"`
`build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`
`build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.PlanGenerationTestSuite"`
`build/sbt "connect/testOnly org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.5.11

Closes #52310 from heyihong/SPARK-53553.

Authored-by: Yihong He <heyihong...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../spark/sql/connect/ClientE2ETestSuite.scala     |   7 +
 .../common/LiteralValueProtoConverter.scala        |  30 ++-
 .../explain-results/function_typedLit.explain      |   2 +-
 .../query-tests/queries/function_typedLit.json     | 249 ++++++++++++++++++++-
 .../queries/function_typedLit.proto.bin            | Bin 9867 -> 10943 bytes
 .../LiteralExpressionProtoConverterSuite.scala     |   4 +
 7 files changed, 287 insertions(+), 9 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index c6561510c035..28498f18cb08 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3419,8 +3419,12 @@ class PlanGenerationTestSuite
       // Handle parameterized scala types e.g.: List, Seq and Map.
       fn.typedLit(Some(1)),
       fn.typedLit(Array(1, 2, 3)),
+      fn.typedLit[Array[Integer]](Array(null, null)),
+      fn.typedLit[Array[(Int, String)]](Array(null, null, (1, "a"), (2, null))),
+      fn.typedLit[Array[Option[(Int, String)]]](Array(None, None, Some((1, "a")))),
       fn.typedLit(Seq(1, 2, 3)),
       fn.typedLit(mutable.LinkedHashMap("a" -> 1, "b" -> 2)),
+      fn.typedLit(mutable.LinkedHashMap[String, Integer]("a" -> null, "b" -> null)),
       fn.typedLit(("a", 2, 1.0)),
       fn.typedLit[Option[Int]](None),
       fn.typedLit[Array[Option[Int]]](Array(Some(1))),
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index e2213003656e..db165c03ad35 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1785,6 +1785,13 @@ class ClientE2ETestSuite
     assert(observation.get.contains("map"))
     assert(observation.get("map") === Map("count" -> 10))
   }
+
+  test("SPARK-53553: null value handling in literals") {
+    val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col"))
+    val result = df.collect()
+    assert(result.length === 1)
+    assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null))
+  }
 }

 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index 286b83d4eae9..16bbeb99557b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -40,6 +40,19 @@ import org.apache.spark.unsafe.types.CalendarInterval

 object LiteralValueProtoConverter {

+  private def setNullValue(
+      builder: proto.Expression.Literal.Builder,
+      dataType: DataType,
+      needDataType: Boolean): proto.Expression.Literal.Builder = {
+    if (needDataType) {
+      builder.setNull(toConnectProtoType(dataType))
+    } else {
+      // The data type is not needed here, but still set the null type
+      // to indicate that the value is null.
+      builder.setNull(ProtoDataTypes.NullType)
+    }
+  }
+
   private def setArrayTypeAfterAddingElements(
       ab: proto.Expression.Literal.Array.Builder,
       elementType: DataType,
@@ -275,6 +288,14 @@ object LiteralValueProtoConverter {
     }

     (literal, dataType) match {
+      case (v: Option[_], _) =>
+        if (v.isDefined) {
+          toLiteralProtoBuilderInternal(v.get, dataType, options, needDataType)
+        } else {
+          setNullValue(builder, dataType, needDataType)
+        }
+      case (null, _) =>
+        setNullValue(builder, dataType, needDataType)
       case (v: mutable.ArraySeq[_], ArrayType(_, _)) =>
         toLiteralProtoBuilderInternal(v.array, dataType, options, needDataType)
       case (v: immutable.ArraySeq[_], ArrayType(_, _)) =>
@@ -287,12 +308,6 @@ object LiteralValueProtoConverter {
         builder.setMap(mapBuilder(v, keyType, valueType, valueContainsNull))
       case (v, structType: StructType) =>
         builder.setStruct(structBuilder(v, structType))
-      case (v: Option[_], _: DataType) =>
-        if (v.isDefined) {
-          toLiteralProtoBuilderInternal(v.get, options, needDataType)
-        } else {
-          builder.setNull(toConnectProtoType(dataType))
-        }
       case (v: LocalTime, timeType: TimeType) =>
         builder.setTime(
           builder.getTimeBuilder
@@ -477,7 +492,7 @@ object LiteralValueProtoConverter {
   }

   private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
-    dataType.getKindCase match {
+    val converter: proto.Expression.Literal => Any = dataType.getKindCase match {
       case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
       case proto.DataType.KindCase.INTEGER => v => v.getInteger
       case proto.DataType.KindCase.LONG => v => v.getLong
@@ -513,6 +528,7 @@ object LiteralValueProtoConverter {
       case _ =>
         throw InvalidPlanInput(s"Unsupported Literal Type: ${dataType.getKindCase}")
     }
+    v => if (v.hasNull) null else converter(v)
   }

   private def getInferredDataType(
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
index 817b923202c5..5daa50bfe38a 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index 5869ec44789d..db7b2a992e94 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -77,7 +77,8 @@
   }, {
     "literal": {
       "null": {
-        "null": {
+        "string": {
+          "collation": "UTF8_BINARY"
         }
       }
     },
@@ -821,6 +822,206 @@
         }
       }
     }
+  }, {
+    "literal": {
+      "array": {
+        "elements": [{
+          "null": {
+            "integer": {
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }],
+        "dataType": {
+          "containsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
+  }, {
+    "literal": {
+      "array": {
+        "elements": [{
+          "null": {
+            "struct": {
+              "fields": [{
+                "name": "_1",
+                "dataType": {
+                  "integer": {
+                  }
+                }
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }, {
+          "struct": {
+            "elements": [{
+              "integer": 1
+            }, {
+              "string": "a"
+            }],
+            "dataTypeStruct": {
+              "fields": [{
+                "name": "_1"
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }, {
+          "struct": {
+            "elements": [{
+              "integer": 2
+            }, {
+              "null": {
+                "string": {
+                  "collation": "UTF8_BINARY"
+                }
+              }
+            }],
+            "dataTypeStruct": {
+              "fields": [{
+                "name": "_1"
+              }, {
+                "name": "_2",
+                "nullable": true
+              }]
+            }
+          }
+        }],
+        "dataType": {
+          "containsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
+  }, {
+    "literal": {
+      "array": {
+        "elements": [{
+          "null": {
+            "struct": {
+              "fields": [{
+                "name": "_1",
+                "dataType": {
+                  "integer": {
+                  }
+                }
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }, {
+          "null": {
+            "null": {
+            }
+          }
+        }, {
+          "struct": {
+            "elements": [{
+              "integer": 1
+            }, {
+              "string": "a"
+            }],
+            "dataTypeStruct": {
+              "fields": [{
+                "name": "_1"
+              }, {
+                "name": "_2",
+                "dataType": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                },
+                "nullable": true
+              }]
+            }
+          }
+        }],
+        "dataType": {
+          "containsNull": true
+        }
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
   }, {
     "literal": {
       "array": {
@@ -891,6 +1092,52 @@
         }
       }
     }
+  }, {
+    "literal": {
{ + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + "values": [{ + "null": { + "integer": { + } + } + }, { + "null": { + "null": { + } + } + }], + "dataType": { + "keyType": { + "string": { + "collation": "UTF8_BINARY" + } + }, + "valueContainsNull": true + } + } + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "typedLit", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } }, { "literal": { "struct": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index 00f80df0e229..6c5ea53d05a9 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index 80c185ee8b3c..9a2827cf8b55 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -53,7 +53,11 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i } } + // The goal of this test is to check that converting a Scala value -> Proto -> Catalyst value + // is equivalent to converting a Scala value directly to a Catalyst value. Seq[(Any, DataType)]( + (Array[String](null, "a", null), ArrayType(StringType)), + (Map[String, String]("a" -> null, "b" -> null), MapType(StringType, StringType)), ( (1, "string", true), StructType( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org