This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 87a71fabb097 [SPARK-53438][CONNECT][SQL] Use CatalystConverter in LiteralExpressionProtoConverter
87a71fabb097 is described below

commit 87a71fabb097e1543a935fae8167bc47a29a127e
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Thu Sep 18 01:35:00 2025 +0800

[SPARK-53438][CONNECT][SQL] Use CatalystConverter in LiteralExpressionProtoConverter

### What changes were proposed in this pull request?

This PR refactors `LiteralExpressionProtoConverter` to use `CatalystTypeConverters` for consistent type conversion, eliminating code duplication and improving maintainability.

**Key changes:**

1. **Simplified `LiteralExpressionProtoConverter.toCatalystExpression()`**: Replaced the large switch statement (86 lines) with a clean 3-line implementation that leverages existing conversion utilities (see the sketch after this list)
2. **Added TIME type support**: Added the missing TIME literal type conversion in `LiteralValueProtoConverter.toScalaValue()`
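A minimal sketch of the idea behind the new implementation (the authoritative change is the `LiteralExpressionProtoConverter.scala` diff at the end of this message; the `toCatalystLiteral` helper name and the example schema below are invented for illustration only):

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types._

// Build a Catalyst literal from an already-decoded Scala value and its data
// type: one generic converter call instead of a per-type switch.
def toCatalystLiteral(scalaValue: Any, dataType: DataType): expressions.Literal = {
  // createToCatalystConverter returns a function that maps external Scala
  // values (Seq, Map, tuples/case classes, BigDecimal, ...) to Catalyst's
  // internal representation (ArrayData, MapData, InternalRow, Decimal, ...).
  val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
  expressions.Literal(convert(scalaValue), dataType)
}

// For example, a Seq of tuples becomes an ArrayData of InternalRows, which is
// the shape that Literal's validation expects for array<struct<...>> values.
val lit = toCatalystLiteral(
  Seq((1, "a"), (2, "b")),
  ArrayType(
    StructType(Seq(
      StructField("_1", IntegerType, nullable = false),
      StructField("_2", StringType)))))
```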
### Why are the changes needed?

1. **Type conversion issues**: Some complex nested data structures (e.g., arrays of case classes) failed to convert to Catalyst's internal representation when using `expressions.Literal.create(...)`.
2. **Inconsistent behaviors**: There were differences in behavior between Spark Connect and classic Spark for the same data types (e.g., Decimal).

### Does this PR introduce _any_ user-facing change?

**Yes** - Users can now successfully use `typedLit` with an array of case classes. Previously, attempting to use `typedlit(Array(CaseClass(1, "a")))` would fail (see the code below for details), but now it works correctly and converts case classes to proper struct representations.

```scala
import org.apache.spark.sql.functions.typedlit

case class CaseClass(a: Int, b: String)

spark.sql("select 1").select(typedlit(Array(CaseClass(1, "a")))).collect()

// Below is the error message:
"""
org.apache.spark.SparkIllegalArgumentException: requirement failed: Literal must have a corresponding value to array<struct<a:int,b:string>>, but class GenericArrayData found.
scala.Predef$.require(Predef.scala:337)
org.apache.spark.sql.catalyst.expressions.Literal$.validateLiteralValue(literals.scala:306)
org.apache.spark.sql.catalyst.expressions.Literal.<init>(literals.scala:456)
org.apache.spark.sql.catalyst.expressions.Literal$.create(literals.scala:206)
org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter$.toCatalystExpression(LiteralExpressionProtoConverter.scala:103)
"""
```

In addition, some Catalyst values have changed (e.g., Decimal `89.97620` -> `89.976200000000000000`). See the changes under `explain-results/` for details.

```scala
import org.apache.spark.sql.functions.typedlit

spark.sql("select 1")
  .select(
    typedlit(BigDecimal(8997620, 5)),
    typedlit(Array(BigDecimal(8997620, 5), BigDecimal(8997621, 5))))
  .explain()

// Current explain() output:
"""
Project [89.97620 AS 89.97620#819, [89.97620,89.97621] AS ARRAY(89.97620BD, 89.97621BD)#820]
"""

// Expected explain() output (i.e., same as the classic mode):
"""
Project [89.976200000000000000 AS 89.976200000000000000#132, [89.976200000000000000,89.976210000000000000] AS ARRAY(89.976200000000000000BD, 89.976210000000000000BD)#133]
"""
```
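The Decimal difference follows from the same converter path: a `scala.BigDecimal` literal is typed as `DecimalType.SYSTEM_DEFAULT`, i.e. `DecimalType(38, 18)`, and `CatalystTypeConverters` rescales the value to that type, which is where the 18 fractional digits in the expected output come from. A spark-shell sketch of just the conversion step (illustrative only; assumes the default non-ANSI overflow handling):

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DecimalType

// BigDecimal(unscaledVal = 8997620, scale = 5) == 89.97620
val bd = BigDecimal(8997620, 5)

// DecimalType.SYSTEM_DEFAULT is DecimalType(38, 18), so the converter
// rescales the resulting Catalyst Decimal to 18 fractional digits.
val convert = CatalystTypeConverters.createToCatalystConverter(DecimalType.SYSTEM_DEFAULT)
println(convert(bd)) // expected: 89.976200000000000000
```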
### How was this patch tested?

`build/sbt "connect-client-jvm/testOnly org.apache.spark.sql.PlanGenerationTestSuite"`

`build/sbt "connect/testOnly org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.4.5

Closes #52188 from heyihong/SPARK-53438.

Authored-by: Yihong He <heyihong...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../common/LiteralValueProtoConverter.scala        |   3 +
 .../explain-results/function_lit_array.explain     |   2 +-
 .../explain-results/function_typedLit.explain      |   2 +-
 .../query-tests/queries/function_typedLit.json     | 190 +++++++++++++++++++++
 .../queries/function_typedLit.proto.bin            | Bin 10943 -> 11642 bytes
 .../planner/LiteralExpressionProtoConverter.scala  |  91 +---------
 7 files changed, 205 insertions(+), 87 deletions(-)

diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 28498f18cb08..b5eabb82b88d 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -307,6 +307,8 @@ class PlanGenerationTestSuite
   private def temporals = createLocalRelation(temporalsSchemaString)
   private def boolean = createLocalRelation(booleanSchemaString)
 
+  private case class CaseClass(a: Int, b: String)
+
   /* Spark Session API */
   test("range") {
     session.range(1, 10, 1, 2)
@@ -3433,6 +3435,8 @@ class PlanGenerationTestSuite
       fn.typedlit[collection.immutable.Map[Int, Option[Int]]](
         collection.immutable.Map(1 -> None)),
       fn.typedLit(Seq(Seq(1, 2, 3), Seq(4, 5, 6), Seq(7, 8, 9))),
+      fn.typedLit(Seq((1, "2", Seq("3", "4")), (5, "6", Seq.empty[String]))),
+      fn.typedLit(Seq(CaseClass(1, "2"), CaseClass(3, "4"), CaseClass(5, "6"))),
       fn.typedLit(
         Seq(
           mutable.LinkedHashMap("a" -> 1, "b" -> 2),
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index 16bbeb99557b..3c07bd5851fb 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -476,6 +476,9 @@ object LiteralValueProtoConverter {
     case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
       SparkIntervalUtils.microsToDuration(literal.getDayTimeInterval)
 
+    case proto.Expression.Literal.LiteralTypeCase.TIME =>
+      SparkDateTimeUtils.nanosToLocalTime(literal.getTime.getNano)
+
     case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
       toScalaArray(literal.getArray)
 
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
index 74d512b6910c..0f4ae8813e89 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_lit_array.explain
@@ -1,2 +1,2 @@
-Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0, [...]
+Project [[] AS ARRAY()#0, [[1],[2],[3]] AS ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))#0, [[[1]],[[2]],[[3]]] AS ARRAY(ARRAY(ARRAY(1)), ARRAY(ARRAY(2)), ARRAY(ARRAY(3)))#0, [true,false] AS ARRAY(true, false)#0, 0x434445 AS X'434445'#0, [9872,9873,9874] AS ARRAY(9872S, 9873S, 9874S)#0, [-8726532,8726532,-8726533] AS ARRAY(-8726532, 8726532, -8726533)#0, [7834609328726531,7834609328726532,7834609328726533] AS ARRAY(7834609328726531L, 7834609328726532L, 7834609328726533L)#0, [2.718281828459045,1.0, [...]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
index 5daa50bfe38a..3c878be34143 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_typedLit.explain
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null A [...]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index db7b2a992e94..1b989d402ee4 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -1394,6 +1394,196 @@
               }
             }
           }
+        }, {
+          "literal": {
+            "array": {
+              "elements": [{
+                "struct": {
+                  "elements": [{
+                    "integer": 1
+                  }, {
+                    "string": "2"
+                  }, {
+                    "array": {
+                      "elements": [{
+                        "string": "3"
+                      }, {
+                        "string": "4"
+                      }],
+                      "dataType": {
+                        "elementType": {
+                          "string": {
+                            "collation": "UTF8_BINARY"
+                          }
+                        },
+                        "containsNull": true
+                      }
+                    }
+                  }],
+                  "dataTypeStruct": {
+                    "fields": [{
+                      "name": "_1"
+                    }, {
+                      "name": "_2",
+                      "dataType": {
+                        "string": {
+                          "collation": "UTF8_BINARY"
+                        }
+                      },
+                      "nullable": true
+                    }, {
+                      "name": "_3",
+                      "nullable": true
+                    }]
+                  }
+                }
+              }, {
+                "struct": {
+                  "elements": [{
+                    "integer": 5
+                  }, {
+                    "string": "6"
+                  }, {
+                    "array": {
+                      "dataType": {
+                        "elementType": {
+                          "string": {
+                            "collation": "UTF8_BINARY"
+                          }
+                        },
+                        "containsNull": true
+                      }
+                    }
+                  }],
+                  "dataTypeStruct": {
+                    "fields": [{
+                      "name": "_1"
+                    }, {
+                      "name": "_2",
+                      "dataType": {
+                        "string": {
+                          "collation": "UTF8_BINARY"
+                        }
+                      },
+                      "nullable": true
+                    }, {
+                      "name": "_3",
+                      "nullable": true
+                    }]
+                  }
+                }
+              }],
+              "dataType": {
+                "containsNull": true
+              }
+            }
+          },
+          "common": {
+            "origin": {
+              "jvmOrigin": {
+                "stackTrace": [{
+                  "classLoaderName": "app",
+                  "declaringClass": "org.apache.spark.sql.functions$",
+                  "methodName": "typedLit",
+                  "fileName": "functions.scala"
+                }, {
+                  "classLoaderName": "app",
+                  "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+                  "methodName": "~~trimmed~anonfun~~",
+                  "fileName": "PlanGenerationTestSuite.scala"
+                }]
+              }
+            }
+          }
+        }, {
+          "literal": {
+            "array": {
+              "elements": [{
+                "struct": {
+                  "elements": [{
+                    "integer": 1
+                  }, {
+                    "string": "2"
+                  }],
+                  "dataTypeStruct": {
+                    "fields": [{
+                      "name": "a"
+                    }, {
+                      "name": "b",
+                      "dataType": {
+                        "string": {
+                          "collation": "UTF8_BINARY"
+                        }
+                      },
+                      "nullable": true
+                    }]
+                  }
+                }
+              }, {
+                "struct": {
+                  "elements": [{
+                    "integer": 3
+                  }, {
+                    "string": "4"
+                  }],
+                  "dataTypeStruct": {
+                    "fields": [{
+                      "name": "a"
+                    }, {
+                      "name": "b",
+                      "dataType": {
+                        "string": {
+                          "collation": "UTF8_BINARY"
+                        }
+                      },
+                      "nullable": true
+                    }]
+                  }
+                }
+              }, {
+                "struct": {
+                  "elements": [{
+                    "integer": 5
+                  }, {
+                    "string": "6"
+                  }],
+                  "dataTypeStruct": {
+                    "fields": [{
+                      "name": "a"
+                    }, {
+                      "name": "b",
+                      "dataType": {
+                        "string": {
+                          "collation": "UTF8_BINARY"
+                        }
+                      },
+                      "nullable": true
+                    }]
+                  }
+                }
+              }],
+              "dataType": {
+                "containsNull": true
+              }
+            }
+          },
+          "common": {
+            "origin": {
+              "jvmOrigin": {
+                "stackTrace": [{
+                  "classLoaderName": "app",
+                  "declaringClass": "org.apache.spark.sql.functions$",
+                  "methodName": "typedLit",
+                  "fileName": "functions.scala"
+                }, {
+                  "classLoaderName": "app",
+                  "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+                  "methodName": "~~trimmed~anonfun~~",
+                  "fileName": "PlanGenerationTestSuite.scala"
+                }]
+              }
+            }
+          }
         }, {
           "literal": {
             "array": {
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin
index 6c5ea53d05a9..734f8576d24e 100644
Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
index be7d67279cc1..4c8911c88188 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.sql.connect.planner
 
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
-import org.apache.spark.sql.connect.common.{InvalidPlanInput, LiteralValueProtoConverter}
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 
 object LiteralExpressionProtoConverter {
 
@@ -33,86 +32,8 @@ object LiteralExpressionProtoConverter {
    */
   def toCatalystExpression(lit: proto.Expression.Literal): expressions.Literal = {
     val dataType = LiteralValueProtoConverter.getDataType(lit)
-    lit.getLiteralTypeCase match {
-      case proto.Expression.Literal.LiteralTypeCase.NULL =>
-        expressions.Literal(null, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
-        expressions.Literal(lit.getBinary.toByteArray, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
-        expressions.Literal(lit.getBoolean, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
-        expressions.Literal(lit.getByte.toByte, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
-        expressions.Literal(lit.getShort.toShort, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
-        expressions.Literal(lit.getInteger, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.LONG =>
-        expressions.Literal(lit.getLong, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
-        expressions.Literal(lit.getFloat, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
-        expressions.Literal(lit.getDouble, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
-        expressions.Literal(Decimal.apply(lit.getDecimal.getValue), dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRING =>
-        expressions.Literal(UTF8String.fromString(lit.getString), dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DATE =>
-        expressions.Literal(lit.getDate, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
-        expressions.Literal(lit.getTimestamp, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
-        expressions.Literal(lit.getTimestampNtz, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
-        val interval = new CalendarInterval(
-          lit.getCalendarInterval.getMonths,
-          lit.getCalendarInterval.getDays,
-          lit.getCalendarInterval.getMicroseconds)
-        expressions.Literal(interval, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
-        expressions.Literal(lit.getYearMonthInterval, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
-        expressions.Literal(lit.getDayTimeInterval, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.TIME =>
-        var precision = TimeType.DEFAULT_PRECISION
-        if (lit.getTime.hasPrecision) {
-          precision = lit.getTime.getPrecision
-        }
-        expressions.Literal(lit.getTime.getNano, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
-        val arrayData = LiteralValueProtoConverter.toScalaArray(lit.getArray)
-        expressions.Literal.create(arrayData, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.MAP =>
-        val mapData = LiteralValueProtoConverter.toScalaMap(lit.getMap)
-        expressions.Literal.create(mapData, dataType)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        val structData = LiteralValueProtoConverter.toScalaStruct(lit.getStruct)
-        val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
-        expressions.Literal(convert(structData), dataType)
-
-      case _ =>
-        throw InvalidPlanInput(
-          s"Unsupported Literal Type: ${lit.getLiteralTypeCase.name}" +
-            s"(${lit.getLiteralTypeCase.getNumber})")
-    }
+    val scalaValue = LiteralValueProtoConverter.toScalaValue(lit)
+    val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
+    expressions.Literal(convert(scalaValue), dataType)
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org