This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 87a71fabb097 [SPARK-53438][CONNECT][SQL] Use CatalystConverter in LiteralExpressionProtoConverter
87a71fabb097 is described below

commit 87a71fabb097e1543a935fae8167bc47a29a127e
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Thu Sep 18 01:35:00 2025 +0800

    [SPARK-53438][CONNECT][SQL] Use CatalystConverter in LiteralExpressionProtoConverter
    
    ### What changes were proposed in this pull request?
    
    This PR refactors the `LiteralExpressionProtoConverter` to use `CatalystTypeConverters` for consistent type conversion, eliminating code duplication and improving maintainability.
    
    **Key changes:**
    1. **Simplified `LiteralExpressionProtoConverter.toCatalystExpression()`**: Replaced the large switch statement (86 lines) with a clean 3-line implementation that leverages existing conversion utilities (see the sketch after this list)
    2. **Added TIME type support**: Added the missing TIME literal type conversion in `LiteralValueProtoConverter.toScalaValue()`
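
    Conceptually, the simplified method decodes the proto literal into a Scala value and then applies a Catalyst converter derived from the literal's data type. A condensed sketch that mirrors the new method body in the diff below (comments added here for illustration):

    ```scala
    import org.apache.spark.connect.proto
    import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
    import org.apache.spark.sql.connect.common.LiteralValueProtoConverter

    def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
      // Resolve the target Catalyst data type from the proto literal.
      val dataType = LiteralValueProtoConverter.getDataType(lit)
      // Decode the proto literal into a plain Scala value.
      val scalaValue = LiteralValueProtoConverter.toScalaValue(lit)
      // Convert the Scala value into Catalyst's internal representation.
      val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
      expressions.Literal(convert(scalaValue), dataType)
    }
    ```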
    
    ### Why are the changes needed?
    
    1. **Type conversion issues**: Some complex nested data structures (e.g., arrays of case classes) failed to convert to Catalyst's internal representation when using `expressions.Literal.create(...)` (see the note after this list).
    2. **Inconsistent behavior**: Spark Connect and classic Spark behaved differently for the same data types (e.g., Decimal).
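
    On the first point: `expressions.Literal.create(...)` converts the value with the type-agnostic `CatalystTypeConverters.convertToCatalyst`, which does not consult the target data type and can leave nested values (such as the case class elements of an array) in a form that `Literal`'s validation rejects. Building the converter from the data type avoids this. A minimal sketch of the type-directed path, with illustrative types (not code from this patch; note `CatalystTypeConverters` is internal to Spark's sql package):

    ```scala
    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types._

    case class CaseClass(a: Int, b: String)

    // Target type of the literal: array<struct<a:int,b:string>>.
    val dataType = ArrayType(StructType(Seq(
      StructField("a", IntegerType, nullable = false),
      StructField("b", StringType))))

    // Type-directed conversion: every level, including the case class
    // elements, is mapped to Catalyst's internal representation.
    val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
    val catalystValue = convert(Array(CaseClass(1, "a")))
    ```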
    
    ### Does this PR introduce _any_ user-facing change?
    
    **Yes** - Users can now successfully use `typedLit` with an array of case classes. Previously, attempting to use `typedlit(Array(CaseClass(1, "a")))` would fail (see the error below), but now it works correctly and converts case classes to proper struct representations.
    
    ```scala
    import org.apache.spark.sql.functions.typedlit
    case class CaseClass(a: Int, b: String)
    spark.sql("select 1").select(typedlit(Array(CaseClass(1, "a")))).collect()
    
    // Below is the error message:
    """
    org.apache.spark.SparkIllegalArgumentException: requirement failed: Literal must have a corresponding value to array<struct<a:int,b:string>>, but class GenericArrayData found.
      scala.Predef$.require(Predef.scala:337)
      org.apache.spark.sql.catalyst.expressions.Literal$.validateLiteralValue(literals.scala:306)
      org.apache.spark.sql.catalyst.expressions.Literal.<init>(literals.scala:456)
      org.apache.spark.sql.catalyst.expressions.Literal$.create(literals.scala:206)
      org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter$.toCatalystExpression(LiteralExpressionProtoConverter.scala:103)
    """
    ```
    
    Additionally, the Catalyst representations of some values have changed (e.g., Decimal `89.97620` -> `89.976200000000000000`; classic Spark resolves Scala `BigDecimal` literals to `DecimalType(38, 18)`, hence the 18-digit scale). Please see the changes in `explain-results/` for details.
    ```scala
    import org.apache.spark.sql.functions.typedlit

    spark.sql("select 1").select(typedlit(BigDecimal(8997620, 5)), typedlit(Array(BigDecimal(8997620, 5), BigDecimal(8997621, 5)))).explain()
    // Current explain() output:
    """
    Project [89.97620 AS 89.97620#819, [89.97620,89.97621] AS ARRAY(89.97620BD, 89.97621BD)#820]
    """
    // Expected explain() output (i.e., the same as classic mode):
    """
    Project [89.976200000000000000 AS 89.976200000000000000#132, [89.976200000000000000,89.976210000000000000] AS ARRAY(89.976200000000000000BD, 89.976210000000000000BD)#133]
    """
    ```

    ### How was this patch tested?
    
    `build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.PlanGenerationTestSuite"`
    `build/sbt "connect/testOnly org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Generated-by: Cursor 1.4.5
    
    Closes #52188 from heyihong/SPARK-53438.
    
    Authored-by: Yihong He <heyihong...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../common/LiteralValueProtoConverter.scala        |   3 +
 .../explain-results/function_lit_array.explain     |   2 +-
 .../explain-results/function_typedLit.explain      |   2 +-
 .../query-tests/queries/function_typedLit.json     | 190 +++++++++++++++++++++
 .../queries/function_typedLit.proto.bin            | Bin 10943 -> 11642 bytes
 .../planner/LiteralExpressionProtoConverter.scala  |  91 +---------
 7 files changed, 205 insertions(+), 87 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 28498f18cb08..b5eabb82b88d 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -307,6 +307,8 @@ class PlanGenerationTestSuite
   private def temporals = createLocalRelation(temporalsSchemaString)
   private def boolean = createLocalRelation(booleanSchemaString)
 
+  private case class CaseClass(a: Int, b: String)
+
   /* Spark Session API */
   test("range") {
     session.range(1, 10, 1, 2)
@@ -3433,6 +3435,8 @@ class PlanGenerationTestSuite
       fn.typedlit[collection.immutable.Map[Int, Option[Int]]](
         collection.immutable.Map(1 -> None)),
       fn.typedLit(Seq(Seq(1, 2, 3), Seq(4, 5, 6), Seq(7, 8, 9))),
+      fn.typedLit(Seq((1, "2", Seq("3", "4")), (5, "6", Seq.empty[String]))),
+      fn.typedLit(Seq(CaseClass(1, "2"), CaseClass(3, "4"), CaseClass(5, "6"))),
       fn.typedLit(
         Seq(
           mutable.LinkedHashMap("a" -> 1, "b" -> 2),
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index 16bbeb99557b..3c07bd5851fb 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -476,6 +476,9 @@ object LiteralValueProtoConverter {
       case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
         SparkIntervalUtils.microsToDuration(literal.getDayTimeInterval)
 
+      case proto.Expression.Literal.LiteralTypeCase.TIME =>
+        SparkDateTimeUtils.nanosToLocalTime(literal.getTime.getNano)
+
       case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
         toScalaArray(literal.getArray)
 
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
index 74d512b6910c..0f4ae8813e89 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
@@ -1,2 +1,2 @@
-Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0, [...]
+Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0, [...]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
index 5daa50bfe38a..3c878be34143 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index db7b2a992e94..1b989d402ee4 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -1394,6 +1394,196 @@
           }
         }
       }
+    }, {
+      "literal": {
+        "array": {
+          "elements": [{
+            "struct": {
+              "elements": [{
+                "integer": 1
+              }, {
+                "string": "2"
+              }, {
+                "array": {
+                  "elements": [{
+                    "string": "3"
+                  }, {
+                    "string": "4"
+                  }],
+                  "dataType": {
+                    "elementType": {
+                      "string": {
+                        "collation": "UTF8_BINARY"
+                      }
+                    },
+                    "containsNull": true
+                  }
+                }
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "_1"
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }, {
+                  "name": "_3",
+                  "nullable": true
+                }]
+              }
+            }
+          }, {
+            "struct": {
+              "elements": [{
+                "integer": 5
+              }, {
+                "string": "6"
+              }, {
+                "array": {
+                  "dataType": {
+                    "elementType": {
+                      "string": {
+                        "collation": "UTF8_BINARY"
+                      }
+                    },
+                    "containsNull": true
+                  }
+                }
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "_1"
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }, {
+                  "name": "_3",
+                  "nullable": true
+                }]
+              }
+            }
+          }],
+          "dataType": {
+            "containsNull": true
+          }
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
+    }, {
+      "literal": {
+        "array": {
+          "elements": [{
+            "struct": {
+              "elements": [{
+                "integer": 1
+              }, {
+                "string": "2"
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "a"
+                }, {
+                  "name": "b",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }, {
+            "struct": {
+              "elements": [{
+                "integer": 3
+              }, {
+                "string": "4"
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "a"
+                }, {
+                  "name": "b",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }, {
+            "struct": {
+              "elements": [{
+                "integer": 5
+              }, {
+                "string": "6"
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "a"
+                }, {
+                  "name": "b",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
+            }
+          }],
+          "dataType": {
+            "containsNull": true
+          }
+        }
+      },
+      "common": {
+        "origin": {
+          "jvmOrigin": {
+            "stackTrace": [{
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.functions$",
+              "methodName": "typedLit",
+              "fileName": "functions.scala"
+            }, {
+              "classLoaderName": "app",
+              "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+              "methodName": "~~trimmed~anonfun~~",
+              "fileName": "PlanGenerationTestSuite.scala"
+            }]
+          }
+        }
+      }
     }, {
       "literal": {
         "array": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
index 6c5ea53d05a9..734f8576d24e 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index be7d67279cc1..4c8911c88188 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.sql.connect.planner
 
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
-import org.apache.spark.sql.connect.common.{InvalidPlanInput, LiteralValueProtoConverter}
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 
 object LiteralExpressionProtoConverter {
 
@@ -33,86 +32,8 @@ object LiteralExpressionProtoConverter {
    */
   def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
     val dataType = LiteralValueProtoConverter.getDataType(lit)
-    lit.getLiteralTypeCase match {
-      case proto.Expression.Literal.LiteralTypeCase.NULL =>
-        expressions.Literal(null, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
-        expressions.Literal(lit.getBinary.toByteArray, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-        expressions.Literal(lit.getBoolean, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
-        expressions.Literal(lit.getByte.toByte, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
-        expressions.Literal(lit.getShort.toShort, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
-        expressions.Literal(lit.getInteger, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.LONG =>
-        expressions.Literal(lit.getLong, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
-        expressions.Literal(lit.getFloat, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-        expressions.Literal(lit.getDouble, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
-        expressions.Literal(Decimal.apply(lit.getDecimal.getValue), dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRING =>
-        expressions.Literal(UTF8String.fromString(lit.getString), dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DATE =>
-        expressions.Literal(lit.getDate, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
-        expressions.Literal(lit.getTimestamp, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
-        expressions.Literal(lit.getTimestampNtz, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
-        val interval = new CalendarInterval(
-          lit.getCalendarInterval.getMonths,
-          lit.getCalendarInterval.getDays,
-          lit.getCalendarInterval.getMicroseconds)
-        expressions.Literal(interval, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
-        expressions.Literal(lit.getYearMonthInterval, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
-        expressions.Literal(lit.getDayTimeInterval, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIME =>
-        var precision = TimeType.DEFAULT_PRECISION
-        if (lit.getTime.hasPrecision) {
-          precision = lit.getTime.getPrecision
-        }
-        expressions.Literal(lit.getTime.getNano, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
-        val arrayData = LiteralValueProtoConverter.toScalaArray(lit.getArray)
-        expressions.Literal.create(arrayData, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.MAP =>
-        val mapData = LiteralValueProtoConverter.toScalaMap(lit.getMap)
-        expressions.Literal.create(mapData, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        val structData = LiteralValueProtoConverter.toScalaStruct(lit.getStruct)
-        val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
-        expressions.Literal(convert(structData), dataType)
-
-      case _ =>
-        throw InvalidPlanInput(
-          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
-            s"(${lit.getLiteralTypeCase.getNumber})")
-    }
+    val scalaValue = LiteralValueProtoConverter.toScalaValue(lit)
+    val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
+    expressions.Literal(convert(scalaValue), dataType)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
