This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 19ca63f82fed [SPARK-53553][CONNECT] Fix handling of null values in LiteralValueProtoConverter
19ca63f82fed is described below

commit 19ca63f82fedfa78a25205b12fe3aacf8c2fc815
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Mon Sep 15 09:46:40 2025 +0800

    [SPARK-53553][CONNECT] Fix handling of null values in LiteralValueProtoConverter
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the handling of null literal values in `LiteralValueProtoConverter` for Spark Connect. The main changes include:
    
    1. **Added proper null value handling**: Created a new `setNullValue` method that correctly sets null values in proto literals with appropriate data type information.
    
    2. **Reordered pattern matching**: Moved null and `Option` handling to the top of the pattern matching in `toLiteralProtoBuilderInternal` to ensure null values are processed before other type-specific logic.
    
    3. **Fixed converter logic**: Updated the `getScalaConverter` method to properly handle null values by checking `hasNull` before applying type-specific conversion logic. A condensed sketch of these changes follows.
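
    For reference, here is a condensed sketch of the new flow, abridged from the diff below (see `LiteralValueProtoConverter.scala` for the full context):

    ```scala
    // Nulls (including empty Options) now match before any type-specific case.
    (literal, dataType) match {
      case (v: Option[_], _) =>
        if (v.isDefined) {
          toLiteralProtoBuilderInternal(v.get, dataType, options, needDataType)
        } else {
          setNullValue(builder, dataType, needDataType)
        }
      case (null, _) =>
        setNullValue(builder, dataType, needDataType)
      // ... type-specific cases follow ...
    }

    // And getScalaConverter wraps each type-specific converter with a null guard:
    v => if (v.hasNull) null else converter(v)
    ```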
    
    ### Why are the changes needed?
    
    The previous implementation had several issues with null value handling:
    
    1. **Incorrect null processing order**: Null values were being processed after type-specific logic, which could lead to exceptions.
    
    2. **Missing null checks in converters**: The converter functions didn't properly check for null values before applying type-specific conversion logic.
    
    ### Does this PR introduce _any_ user-facing change?
    
    **Yes**. This PR fixes a bug where null values in literals (especially in arrays and maps) were not being properly handled in Spark Connect. Users who were experiencing issues with null value serialization in complex types should now see correct behavior.
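
    For example (mirroring the new E2E test below), a literal array containing nulls now round-trips correctly instead of raising an exception:

    ```scala
    val df = spark.sql("select 1")
      .select(typedlit(Array[Integer](1, null)).as("arr_col"))
    // Returns Array(1, null); before this fix, null elements in complex
    // literals could trigger conversion errors.
    df.collect()(0).getAs[Array[Integer]]("arr_col")
    ```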
    
    ### How was this patch tested?
    
    `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z SPARK-53553"`
    `build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`
    `build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.PlanGenerationTestSuite"`
    `build/sbt "connect/testOnly org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.5.11
    
    Closes #52310 from heyihong/SPARK-53553.
    
    Authored-by: Yihong He <heyihong...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../spark/sql/connect/ClientE2ETestSuite.scala     |   7 +
 .../common/LiteralValueProtoConverter.scala        |  30 ++-
 .../explain-results/function_typedLit.explain      |   2 +-
 .../query-tests/queries/function_typedLit.json     | 249 ++++++++++++++++++++-
 .../queries/function_typedLit.proto.bin            | Bin 9867 -> 10943 bytes
 .../LiteralExpressionProtoConverterSuite.scala     |   4 +
 7 files changed, 287 insertions(+), 9 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index c6561510c035..28498f18cb08 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3419,8 +3419,12 @@ class PlanGenerationTestSuite
       // Handle parameterized scala types e.g.: List, Seq and Map.
       fn.typedLit(Some(1)),
       fn.typedLit(Array(1, 2, 3)),
+      fn.typedLit[Array[Integer]](Array(null, null)),
+      fn.typedLit[Array[(Int, String)]](Array(null, null, (1, "a"), (2, null))),
+      fn.typedLit[Array[Option[(Int, String)]]](Array(None, None, Some((1, "a")))),
       fn.typedLit(Seq(1, 2, 3)),
       fn.typedLit(mutable.LinkedHashMap("a" -> 1, "b" -> 2)),
+      fn.typedLit(mutable.LinkedHashMap[String, Integer]("a" -> null, "b" -> null)),
       fn.typedLit(("a", 2, 1.0)),
       fn.typedLit[Option[Int]](None),
       fn.typedLit[Array[Option[Int]]](Array(Some(1))),
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
index e2213003656e..db165c03ad35 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala
@@ -1785,6 +1785,13 @@ class ClientE2ETestSuite
     assert(observation.get.contains("map"))
     assert(observation.get("map") === Map("count" -> 10))
   }
+
+  test("SPARK-53553: null value handling in literals") {
+    val df = spark.sql("select 1").select(typedlit(Array[Integer](1, null)).as("arr_col"))
+    val result = df.collect()
+    assert(result.length === 1)
+    assert(result(0).getAs[Array[Integer]]("arr_col") === Array(1, null))
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index 286b83d4eae9..16bbeb99557b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -40,6 +40,19 @@ import org.apache.spark.unsafe.types.CalendarInterval
 
 object LiteralValueProtoConverter {
 
+  private def setNullValue(
+      builder: proto.Expression.Literal.Builder,
+      dataType: DataType,
+      needDataType: Boolean): proto.Expression.Literal.Builder = {
+    if (needDataType) {
+      builder.setNull(toConnectProtoType(dataType))
+    } else {
+      // The data type is not needed here, but we still set the null type
+      // to indicate that the value is null.
+      builder.setNull(ProtoDataTypes.NullType)
+    }
+  }
+
   private def setArrayTypeAfterAddingElements(
       ab: proto.Expression.Literal.Array.Builder,
       elementType: DataType,
@@ -275,6 +288,14 @@ object LiteralValueProtoConverter {
     }
 
     (literal, dataType) match {
+      case (v: Option[_], _) =>
+        if (v.isDefined) {
+          toLiteralProtoBuilderInternal(v.get, dataType, options, needDataType)
+        } else {
+          setNullValue(builder, dataType, needDataType)
+        }
+      case (null, _) =>
+        setNullValue(builder, dataType, needDataType)
       case (v: mutable.ArraySeq[_], ArrayType(_, _)) =>
         toLiteralProtoBuilderInternal(v.array, dataType, options, needDataType)
       case (v: immutable.ArraySeq[_], ArrayType(_, _)) =>
@@ -287,12 +308,6 @@ object LiteralValueProtoConverter {
         builder.setMap(mapBuilder(v, keyType, valueType, valueContainsNull))
       case (v, structType: StructType) =>
         builder.setStruct(structBuilder(v, structType))
-      case (v: Option[_], _: DataType) =>
-        if (v.isDefined) {
-          toLiteralProtoBuilderInternal(v.get, options, needDataType)
-        } else {
-          builder.setNull(toConnectProtoType(dataType))
-        }
       case (v: LocalTime, timeType: TimeType) =>
         builder.setTime(
           builder.getTimeBuilder
@@ -477,7 +492,7 @@ object LiteralValueProtoConverter {
   }
 
   private def getScalaConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
-    dataType.getKindCase match {
+    val converter: proto.Expression.Literal => Any = dataType.getKindCase match {
       case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
       case proto.DataType.KindCase.INTEGER => v => v.getInteger
       case proto.DataType.KindCase.LONG => v => v.getLong
@@ -513,6 +528,7 @@ object LiteralValueProtoConverter {
       case _ =>
        throw InvalidPlanInput(s"Unsupported Literal Type: ${dataType.getKindCase}")
     }
+    v => if (v.hasNull) null else converter(v)
   }
 
   private def getInferredDataType(
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
index 817b923202c5..5daa50bfe38a 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index 5869ec44789d..db7b2a992e94 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -77,7 +77,8 @@
     }, {
       "literal": {
         "null": {
-          "null": {
+          "string": {
+            "collation": "UTF8_BINARY"
           }
         }
       },
@@ -821,6 +822,206 @@
           }
         }
       }
+    }, {
+      "literal": {
+        "array": {
+          "elements": [{
+            "null": {
+              "integer": {
+              }
+            }
+          }, {
+            "null": {
+              "null": {
+              }
+            }
+          }],
+          "dataType": {
+            "containsNull": true
+          }
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
+    }, {
+      "literal": {
+        "array": {
+          "elements": [{
+            "null": {
+              "struct": {
+                "fields": [{
+                  "name": "_1",
+                  "dataType": {
+                    "integer": {
+                    }
+                  }
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }, {
+            "null": {
+              "null": {
+              }
+            }
+          }, {
+            "struct": {
+              "elements": [{
+                "integer": 1
+              }, {
+                "string": "a"
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "_1"
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }, {
+            "struct": {
+              "elements": [{
+                "integer": 2
+              }, {
+                "null": {
+                  "string": {
+                    "collation": "UTF8_BINARY"
+                  }
+                }
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "_1"
+                }, {
+                  "name": "_2",
+                  "nullable": true
+                }]
+              }
+            }
+          }],
+          "dataType": {
+            "containsNull": true
+          }
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
+    }, {
+      "literal": {
+        "array": {
+          "elements": [{
+            "null": {
+              "struct": {
+                "fields": [{
+                  "name": "_1",
+                  "dataType": {
+                    "integer": {
+                    }
+                  }
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }, {
+            "null": {
+              "null": {
+              }
+            }
+          }, {
+            "struct": {
+              "elements": [{
+                "integer": 1
+              }, {
+                "string": "a"
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "_1"
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }],
+          "dataType": {
+            "containsNull": true
+          }
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
     }, {
       "literal": {
         "array": {
@@ -891,6 +1092,52 @@
           }
         }
       }
+    }, {
+      "literal": {
+        "map": {
+          "keys": [{
+            "string": "a"
+          }, {
+            "string": "b"
+          }],
+          "values": [{
+            "null": {
+              "integer": {
+              }
+            }
+          }, {
+            "null": {
+              "null": {
+              }
+            }
+          }],
+          "dataType": {
+            "keyType": {
+              "string": {
+                "collation": "UTF8_BINARY"
+              }
+            },
+            "valueContainsNull": true
+          }
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
     }, {
       "literal": {
         "struct": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
index 00f80df0e229..6c5ea53d05a9 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
index 80c185ee8b3c..9a2827cf8b55 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala
@@ -53,7 +53,11 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
     }
   }
 
+  // The goal of this test is to check that converting a Scala value -> Proto -> Catalyst value
+  // is equivalent to converting a Scala value directly to a Catalyst value.
   Seq[(Any, DataType)](
+    (Array[String](null, "a", null), ArrayType(StringType)),
+    (Map[String, String]("a" -> null, "b" -> null), MapType(StringType, StringType)),
     (
       (1, "string", true),
       StructType(

