This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cdc73ad36e5 [SPARK-41506][CONNECT][PYTHON] Refactor LiteralExpression to support DataType
cdc73ad36e5 is described below

commit cdc73ad36e53544add2cfb7ea66941202014303e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Dec 14 09:50:05 2022 +0900

    [SPARK-41506][CONNECT][PYTHON] Refactor LiteralExpression to support DataType
    
    ### What changes were proposed in this pull request?
    1. The existing `LiteralExpression` is a mixture of `Literal`, `CreateArray`, `CreateStruct` and `CreateMap`. Since the collection functions `array`, `struct` and `create_map` have been added, the `CreateXXX` expressions can be replaced with `UnresolvedFunction`.
    2. Add a `dataType` field to `LiteralExpression` so the DataType can be specified when needed; a special case is the typed null.
    3. Make it the responsibility of the `lit` function, not `LiteralExpression` itself, to infer the DataType (a sketch of the resulting behavior follows).
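
    As a rough, hedged sketch of the resulting behavior (simplified from the Python client changes in this patch; `CF` is just a local alias, and the comments describe intent rather than exact reprs):

        from pyspark.sql.connect.column import LiteralExpression
        from pyspark.sql.connect import functions as CF
        from pyspark.sql.types import IntegerType, StringType

        # `lit` infers the DataType from the Python value ...
        CF.lit(1)      # Column(LiteralExpression(1, IntegerType()))
        CF.lit("abc")  # Column(LiteralExpression("abc", StringType()))

        # ... while a typed null is built by passing the DataType explicitly;
        # it serializes to the new `typed_null` literal field.
        LiteralExpression(None, StringType())

        # Collection values no longer become Literal messages: `lit` delegates
        # to the collection functions, which plan UnresolvedFunction nodes.
        CF.lit([1, 2, 3])   # same plan as CF.array(CF.lit(1), CF.lit(2), CF.lit(3))
        CF.lit((1, "a"))    # same plan as CF.struct(CF.lit(1), CF.lit("a"))
        CF.lit({"k": "v"})  # same plan as CF.create_map(CF.lit("k"), CF.lit("v"))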
    
    ### Why are the changes needed?
    Refactor LiteralExpression to support DataType
    
    ### Does this PR introduce _any_ user-facing change?
    No. `LiteralExpression` is an internal class and should not be exposed to end users.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #39047 from zhengruifeng/connect_lit_datatype.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |  23 +--
 .../planner/LiteralValueProtoConverter.scala       |  31 +--
 python/pyspark/sql/connect/column.py               | 213 +++++++++++++++------
 python/pyspark/sql/connect/functions.py            |  13 +-
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  86 ++-------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 115 +----------
 .../sql/tests/connect/test_connect_column.py       | 122 +++++++++++-
 .../connect/test_connect_column_expressions.py     |  74 +++++--
 8 files changed, 365 insertions(+), 312 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 6c0facbfeee..c906f15e0a6 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -77,9 +77,7 @@ message Expression {
       int32 year_month_interval = 20;
       int64 day_time_interval = 21;
 
-      Array array = 22;
-      Struct struct = 23;
-      Map map = 24;
+      DataType typed_null = 22;
     }
 
    // whether the literal type should be treated as a nullable type. Applies to
@@ -107,25 +105,6 @@ message Expression {
       int32 days = 2;
       int64 microseconds = 3;
     }
-
-    message Struct {
-      // A possibly heterogeneously typed list of literals
-      repeated Literal fields = 1;
-    }
-
-    message Array {
-      // A homogeneously typed list of literals
-      repeated Literal values = 1;
-    }
-
-    message Map {
-      repeated Pair pairs = 1;
-
-      message Pair {
-        Literal key = 1;
-        Literal value = 2;
-      }
-    }
   }
 
  // An unresolved attribute that is not explicitly bound to a specific column, but the column
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index 5a54ad9ac64..46f6db64b8c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -17,11 +17,8 @@
 
 package org.apache.spark.sql.connect.planner
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.connect.proto
-import org.apache.spark.sql.catalyst.{expressions, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateStruct}
+import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -99,20 +96,6 @@ object LiteralValueProtoConverter {
       case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
         expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
 
-      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
-        val literals = lit.getArray.getValuesList.asScala.toArray.map(toCatalystExpression)
-        CreateArray(literals)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        val literals = lit.getStruct.getFieldsList.asScala.toArray.map(toCatalystExpression)
-        CreateStruct(literals)
-
-      case proto.Expression.Literal.LiteralTypeCase.MAP =>
-        val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair =>
-          toCatalystExpression(pair.getKey) :: toCatalystExpression(pair.getValue) :: Nil
-        }
-        CreateMap(literals)
-
       case _ =>
         throw InvalidPlanInput(
           s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
@@ -122,18 +105,6 @@ object LiteralValueProtoConverter {
 
   def toCatalystValue(lit: proto.Expression.Literal): Any = {
     lit.getLiteralTypeCase match {
-      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
-        lit.getArray.getValuesList.asScala.toArray.map(toCatalystValue)
-
-      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
-        val literals = lit.getStruct.getFieldsList.asScala.map(toCatalystValue).toSeq
-        InternalRow(literals: _*)
-
-      case proto.Expression.Literal.LiteralTypeCase.MAP =>
-        lit.getMap.getPairsList.asScala.toArray.map { pair =>
-          toCatalystValue(pair.getKey) -> toCatalystValue(pair.getValue)
-        }.toMap
-
       case proto.Expression.Literal.LiteralTypeCase.STRING => lit.getString
 
      case _ => toCatalystExpression(lit).asInstanceOf[expressions.Literal].value
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 58d4e3dc41e..673c486ee1f 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -21,7 +21,24 @@ import json
 import decimal
 import datetime
 
-from pyspark.sql.types import TimestampType, DayTimeIntervalType, DataType, DateType
+from pyspark.sql.types import (
+    DateType,
+    NullType,
+    BooleanType,
+    BinaryType,
+    ByteType,
+    ShortType,
+    IntegerType,
+    LongType,
+    FloatType,
+    DoubleType,
+    DecimalType,
+    StringType,
+    DataType,
+    TimestampType,
+    TimestampNTZType,
+    DayTimeIntervalType,
+)
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.types import pyspark_types_to_proto_types
@@ -31,7 +48,10 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.client import SparkConnectClient
     import pyspark.sql.connect.proto as proto
 
-
+JVM_BYTE_MIN = -(1 << 7)
+JVM_BYTE_MAX = (1 << 7) - 1
+JVM_SHORT_MIN = -(1 << 15)
+JVM_SHORT_MAX = (1 << 15) - 1
 JVM_INT_MIN = -(1 << 31)
 JVM_INT_MAX = (1 << 31) - 1
 JVM_LONG_MIN = -(1 << 63)
@@ -166,78 +186,147 @@ class LiteralExpression(Expression):
    The Python types are converted best effort into the relevant proto types. On the Spark Connect
     server side, the proto types are converted to the Catalyst equivalents."""
 
-    def __init__(self, value: Any) -> None:
+    def __init__(self, value: Any, dataType: DataType) -> None:
         super().__init__()
-        self._value = value
 
-    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
-        """Converts the literal expression to the literal in proto.
+        assert isinstance(
+            dataType,
+            (
+                NullType,
+                BinaryType,
+                BooleanType,
+                ByteType,
+                ShortType,
+                IntegerType,
+                LongType,
+                FloatType,
+                DoubleType,
+                DecimalType,
+                StringType,
+                DateType,
+                TimestampType,
+                TimestampNTZType,
+                DayTimeIntervalType,
+            ),
+        )
 
-        TODO(SPARK-40533) This method always assumes the largest type and can thus
-             create weird interpretations of the literal."""
+        if isinstance(dataType, NullType):
+            assert value is None
+
+        if value is not None:
+            if isinstance(dataType, BinaryType):
+                assert isinstance(value, (bytes, bytearray))
+            elif isinstance(dataType, BooleanType):
+                assert isinstance(value, bool)
+            elif isinstance(dataType, ByteType):
+                assert isinstance(value, int) and JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
+            elif isinstance(dataType, ShortType):
+                assert isinstance(value, int) and JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
+            elif isinstance(dataType, IntegerType):
+                assert isinstance(value, int) and JVM_INT_MIN <= int(value) <= JVM_INT_MAX
+            elif isinstance(dataType, LongType):
+                assert isinstance(value, int) and JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
+            elif isinstance(dataType, FloatType):
+                assert isinstance(value, float)
+            elif isinstance(dataType, DoubleType):
+                assert isinstance(value, float)
+            elif isinstance(dataType, DecimalType):
+                assert isinstance(value, decimal.Decimal)
+            elif isinstance(dataType, StringType):
+                assert isinstance(value, str)
+            elif isinstance(dataType, DateType):
+                assert isinstance(value, (datetime.date, datetime.datetime))
+                if isinstance(value, datetime.date):
+                    value = DateType().toInternal(value)
+                else:
+                    value = DateType().toInternal(value.date())
+            elif isinstance(dataType, TimestampType):
+                assert isinstance(value, datetime.datetime)
+                value = TimestampType().toInternal(value)
+            elif isinstance(dataType, TimestampNTZType):
+                assert isinstance(value, datetime.datetime)
+                value = TimestampNTZType().toInternal(value)
+            elif isinstance(dataType, DayTimeIntervalType):
+                assert isinstance(value, datetime.timedelta)
+                value = DayTimeIntervalType().toInternal(value)
+                assert value is not None
+            else:
+                raise ValueError(f"Unsupported Data Type {dataType}")
 
-        from pyspark.sql.connect.functions import lit
+        self._value = value
+        self._dataType = dataType
+
+    @classmethod
+    def _infer_type(cls, value: Any) -> DataType:
+        if value is None:
+            return NullType()
+        elif isinstance(value, (bytes, bytearray)):
+            return BinaryType()
+        elif isinstance(value, bool):
+            return BooleanType()
+        elif isinstance(value, int):
+            if JVM_INT_MIN <= value <= JVM_INT_MAX:
+                return IntegerType()
+            elif JVM_LONG_MIN <= value <= JVM_LONG_MAX:
+                return LongType()
+            else:
+                raise ValueError(f"integer {value} out of bounds")
+        elif isinstance(value, float):
+            return DoubleType()
+        elif isinstance(value, str):
+            return StringType()
+        elif isinstance(value, decimal.Decimal):
+            return DecimalType()
+        elif isinstance(value, datetime.datetime):
+            return TimestampType()
+        elif isinstance(value, datetime.date):
+            return DateType()
+        elif isinstance(value, datetime.timedelta):
+            return DayTimeIntervalType()
+        else:
+            raise ValueError(f"Unsupported Data Type {type(value).__name__}")
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        """Converts the literal expression to the literal in proto."""
 
         expr = proto.Expression()
-        if self._value is None:
+
+        if isinstance(self._dataType, NullType):
             expr.literal.null = True
-        elif isinstance(self._value, (bytes, bytearray)):
+        elif self._value is None:
+            expr.literal.typed_null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
+        elif isinstance(self._dataType, BinaryType):
             expr.literal.binary = bytes(self._value)
-        elif isinstance(self._value, bool):
+        elif isinstance(self._dataType, BooleanType):
             expr.literal.boolean = bool(self._value)
-        elif isinstance(self._value, int):
-            if JVM_INT_MIN <= self._value <= JVM_INT_MAX:
-                expr.literal.integer = int(self._value)
-            elif JVM_LONG_MIN <= self._value <= JVM_LONG_MAX:
-                expr.literal.long = int(self._value)
-            else:
-                raise ValueError(f"integer {self._value} out of bounds")
-        elif isinstance(self._value, float):
+        elif isinstance(self._dataType, ByteType):
+            expr.literal.byte = int(self._value)
+        elif isinstance(self._dataType, ShortType):
+            expr.literal.short = int(self._value)
+        elif isinstance(self._dataType, IntegerType):
+            expr.literal.integer = int(self._value)
+        elif isinstance(self._dataType, LongType):
+            expr.literal.long = int(self._value)
+        elif isinstance(self._dataType, FloatType):
+            expr.literal.float = float(self._value)
+        elif isinstance(self._dataType, DoubleType):
             expr.literal.double = float(self._value)
-        elif isinstance(self._value, str):
-            expr.literal.string = str(self._value)
-        elif isinstance(self._value, decimal.Decimal):
+        elif isinstance(self._dataType, DecimalType):
             expr.literal.decimal.value = str(self._value)
-            expr.literal.decimal.precision = int(decimal.getcontext().prec)
-        elif isinstance(self._value, datetime.datetime):
-            expr.literal.timestamp = TimestampType().toInternal(self._value)
-        elif isinstance(self._value, datetime.date):
-            expr.literal.date = DateType().toInternal(self._value)
-        elif isinstance(self._value, datetime.timedelta):
-            interval = DayTimeIntervalType().toInternal(self._value)
-            assert interval is not None
-            expr.literal.day_time_interval = int(interval)
-        elif isinstance(self._value, list):
-            expr.literal.array.SetInParent()
-            for item in list(self._value):
-                if isinstance(item, Column):
-                    expr.literal.array.values.append(item.to_plan(session).literal)
-                else:
-                    expr.literal.array.values.append(lit(item).to_plan(session).literal)
-        elif isinstance(self._value, tuple):
-            expr.literal.struct.SetInParent()
-            for item in list(self._value):
-                if isinstance(item, Column):
-                    expr.literal.struct.fields.append(item.to_plan(session).literal)
-                else:
-                    expr.literal.struct.fields.append(lit(item).to_plan(session).literal)
-        elif isinstance(self._value, dict):
-            expr.literal.map.SetInParent()
-            for key, value in dict(self._value).items():
-                pair = proto.Expression.Literal.Map.Pair()
-                if isinstance(key, Column):
-                    pair.key.CopyFrom(key.to_plan(session).literal)
-                else:
-                    pair.key.CopyFrom(lit(key).to_plan(session).literal)
-                if isinstance(value, Column):
-                    pair.value.CopyFrom(value.to_plan(session).literal)
-                else:
-                    pair.value.CopyFrom(lit(value).to_plan(session).literal)
-                expr.literal.map.pairs.append(pair)
-        elif isinstance(self._value, Column):
-            expr.CopyFrom(self._value.to_plan(session))
+            expr.literal.decimal.precision = self._dataType.precision
+            expr.literal.decimal.scale = self._dataType.scale
+        elif isinstance(self._dataType, StringType):
+            expr.literal.string = str(self._value)
+        elif isinstance(self._dataType, DateType):
+            expr.literal.date = int(self._value)
+        elif isinstance(self._dataType, TimestampType):
+            expr.literal.timestamp = int(self._value)
+        elif isinstance(self._dataType, TimestampNTZType):
+            expr.literal.timestamp_ntz = int(self._value)
+        elif isinstance(self._dataType, DayTimeIntervalType):
+            expr.literal.day_time_interval = int(self._value)
         else:
-            raise ValueError(f"Could not convert literal for type 
{type(self._value)}")
+            raise ValueError(f"Unsupported Data Type {self._dataType}")
 
         return expr
 
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index dccb6d6e0c7..a81de4f9ba6 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -92,8 +92,19 @@ column = col
 def lit(col: Any) -> Column:
     if isinstance(col, Column):
         return col
+    elif isinstance(col, list):
+        return array(*[lit(c) for c in col])
+    elif isinstance(col, tuple):
+        return struct(*[lit(c) for c in col])
+    elif isinstance(col, dict):
+        cols = []
+        for k, v in col.items():
+            cols.append(lit(k))
+            cols.append(lit(v))
+        return create_map(*cols)
     else:
-        return Column(LiteralExpression(col))
+        dataType = LiteralExpression._infer_type(col)
+        return Column(LiteralExpression(col, dataType))
 
 
 # def bitwiseNOT(col: "ColumnOrName") -> Column:
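
For reference, a hedged sketch of the plan shapes this dispatch yields (mirroring the updated tests later in this patch; `fun` is assumed to be `pyspark.sql.connect.functions`):

    p = fun.lit({"this": "is", 12: [12, 32, 43]}).to_plan(None)
    p.unresolved_function.function_name                                   # "map"
    p.unresolved_function.arguments[0].literal.string                     # "this"
    p.unresolved_function.arguments[2].literal.integer                    # 12
    p.unresolved_function.arguments[3].unresolved_function.function_name  # "array"
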
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 91c57a9ef22..a311efc063a 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xf4\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x11\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
 )
 
 
@@ -42,10 +42,6 @@ _EXPRESSION_CAST = _EXPRESSION.nested_types_by_name["Cast"]
 _EXPRESSION_LITERAL = _EXPRESSION.nested_types_by_name["Literal"]
 _EXPRESSION_LITERAL_DECIMAL = _EXPRESSION_LITERAL.nested_types_by_name["Decimal"]
 _EXPRESSION_LITERAL_CALENDARINTERVAL = _EXPRESSION_LITERAL.nested_types_by_name["CalendarInterval"]
-_EXPRESSION_LITERAL_STRUCT = _EXPRESSION_LITERAL.nested_types_by_name["Struct"]
-_EXPRESSION_LITERAL_ARRAY = _EXPRESSION_LITERAL.nested_types_by_name["Array"]
-_EXPRESSION_LITERAL_MAP = _EXPRESSION_LITERAL.nested_types_by_name["Map"]
-_EXPRESSION_LITERAL_MAP_PAIR = _EXPRESSION_LITERAL_MAP.nested_types_by_name["Pair"]
 _EXPRESSION_UNRESOLVEDATTRIBUTE = _EXPRESSION.nested_types_by_name["UnresolvedAttribute"]
 _EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFunction"]
 _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"]
@@ -86,42 +82,6 @@ Expression = _reflection.GeneratedProtocolMessageType(
                        # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.CalendarInterval)
                     },
                 ),
-                "Struct": _reflection.GeneratedProtocolMessageType(
-                    "Struct",
-                    (_message.Message,),
-                    {
-                        "DESCRIPTOR": _EXPRESSION_LITERAL_STRUCT,
-                        "__module__": "spark.connect.expressions_pb2"
-                        # 
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Struct)
-                    },
-                ),
-                "Array": _reflection.GeneratedProtocolMessageType(
-                    "Array",
-                    (_message.Message,),
-                    {
-                        "DESCRIPTOR": _EXPRESSION_LITERAL_ARRAY,
-                        "__module__": "spark.connect.expressions_pb2"
-                        # 
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Array)
-                    },
-                ),
-                "Map": _reflection.GeneratedProtocolMessageType(
-                    "Map",
-                    (_message.Message,),
-                    {
-                        "Pair": _reflection.GeneratedProtocolMessageType(
-                            "Pair",
-                            (_message.Message,),
-                            {
-                                "DESCRIPTOR": _EXPRESSION_LITERAL_MAP_PAIR,
-                                "__module__": "spark.connect.expressions_pb2"
-                                # 
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map.Pair)
-                            },
-                        ),
-                        "DESCRIPTOR": _EXPRESSION_LITERAL_MAP,
-                        "__module__": "spark.connect.expressions_pb2"
-                        # 
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map)
-                    },
-                ),
                 "DESCRIPTOR": _EXPRESSION_LITERAL,
                 "__module__": "spark.connect.expressions_pb2"
                 # 
@@protoc_insertion_point(class_scope:spark.connect.Expression.Literal)
@@ -182,10 +142,6 @@ _sym_db.RegisterMessage(Expression.Cast)
 _sym_db.RegisterMessage(Expression.Literal)
 _sym_db.RegisterMessage(Expression.Literal.Decimal)
 _sym_db.RegisterMessage(Expression.Literal.CalendarInterval)
-_sym_db.RegisterMessage(Expression.Literal.Struct)
-_sym_db.RegisterMessage(Expression.Literal.Array)
-_sym_db.RegisterMessage(Expression.Literal.Map)
-_sym_db.RegisterMessage(Expression.Literal.Map.Pair)
 _sym_db.RegisterMessage(Expression.UnresolvedAttribute)
 _sym_db.RegisterMessage(Expression.UnresolvedFunction)
 _sym_db.RegisterMessage(Expression.ExpressionString)
@@ -197,31 +153,23 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 2754
+    _EXPRESSION._serialized_end = 2291
     _EXPRESSION_CAST._serialized_start = 640
     _EXPRESSION_CAST._serialized_end = 785
     _EXPRESSION_LITERAL._serialized_start = 788
-    _EXPRESSION_LITERAL._serialized_end = 2246
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1684
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1801
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1803
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1901
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 1903
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 1970
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 1972
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 2038
-    _EXPRESSION_LITERAL_MAP._serialized_start = 2041
-    _EXPRESSION_LITERAL_MAP._serialized_end = 2230
-    _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2114
-    _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2230
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2248
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2318
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2321
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2525
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2527
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2577
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2579
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2619
-    _EXPRESSION_ALIAS._serialized_start = 2621
-    _EXPRESSION_ALIAS._serialized_end = 2741
+    _EXPRESSION_LITERAL._serialized_end = 1783
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1550
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1667
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1669
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1767
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 1785
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 1855
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 1858
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2062
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2064
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2114
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2116
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2156
+    _EXPRESSION_ALIAS._serialized_start = 2158
+    _EXPRESSION_ALIAS._serialized_end = 2278
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 2c486f62a9d..d710fcab52d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -190,87 +190,6 @@ class Expression(google.protobuf.message.Message):
                 ],
             ) -> None: ...
 
-        class Struct(google.protobuf.message.Message):
-            DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-            FIELDS_FIELD_NUMBER: builtins.int
-            @property
-            def fields(
-                self,
-            ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
-                global___Expression.Literal
-            ]:
-                """A possibly heterogeneously typed list of literals"""
-            def __init__(
-                self,
-                *,
-                fields: collections.abc.Iterable[global___Expression.Literal] | None = ...,
-            ) -> None: ...
-            def ClearField(
-                self, field_name: typing_extensions.Literal["fields", b"fields"]
-            ) -> None: ...
-
-        class Array(google.protobuf.message.Message):
-            DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-            VALUES_FIELD_NUMBER: builtins.int
-            @property
-            def values(
-                self,
-            ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
-                global___Expression.Literal
-            ]:
-                """A homogeneously typed list of literals"""
-            def __init__(
-                self,
-                *,
-                values: collections.abc.Iterable[global___Expression.Literal] | None = ...,
-            ) -> None: ...
-            def ClearField(
-                self, field_name: typing_extensions.Literal["values", b"values"]
-            ) -> None: ...
-
-        class Map(google.protobuf.message.Message):
-            DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-            class Pair(google.protobuf.message.Message):
-                DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-                KEY_FIELD_NUMBER: builtins.int
-                VALUE_FIELD_NUMBER: builtins.int
-                @property
-                def key(self) -> global___Expression.Literal: ...
-                @property
-                def value(self) -> global___Expression.Literal: ...
-                def __init__(
-                    self,
-                    *,
-                    key: global___Expression.Literal | None = ...,
-                    value: global___Expression.Literal | None = ...,
-                ) -> None: ...
-                def HasField(
-                    self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
-                ) -> builtins.bool: ...
-                def ClearField(
-                    self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
-                ) -> None: ...
-
-            PAIRS_FIELD_NUMBER: builtins.int
-            @property
-            def pairs(
-                self,
-            ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
-                global___Expression.Literal.Map.Pair
-            ]: ...
-            def __init__(
-                self,
-                *,
-                pairs: collections.abc.Iterable[global___Expression.Literal.Map.Pair] | None = ...,
-            ) -> None: ...
-            def ClearField(
-                self, field_name: typing_extensions.Literal["pairs", b"pairs"]
-            ) -> None: ...
-
         NULL_FIELD_NUMBER: builtins.int
         BINARY_FIELD_NUMBER: builtins.int
         BOOLEAN_FIELD_NUMBER: builtins.int
@@ -288,9 +207,7 @@ class Expression(google.protobuf.message.Message):
         CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int
         YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int
         DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int
-        ARRAY_FIELD_NUMBER: builtins.int
-        STRUCT_FIELD_NUMBER: builtins.int
-        MAP_FIELD_NUMBER: builtins.int
+        TYPED_NULL_FIELD_NUMBER: builtins.int
         NULLABLE_FIELD_NUMBER: builtins.int
         TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
         null: builtins.bool
@@ -316,11 +233,7 @@ class Expression(google.protobuf.message.Message):
         year_month_interval: builtins.int
         day_time_interval: builtins.int
         @property
-        def array(self) -> global___Expression.Literal.Array: ...
-        @property
-        def struct(self) -> global___Expression.Literal.Struct: ...
-        @property
-        def map(self) -> global___Expression.Literal.Map: ...
+        def typed_null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
         nullable: builtins.bool
         """whether the literal type should be treated as a nullable type. 
Applies to
         all members of union other than the Typed null (which should directly
@@ -351,17 +264,13 @@ class Expression(google.protobuf.message.Message):
            calendar_interval: global___Expression.Literal.CalendarInterval | None = ...,
             year_month_interval: builtins.int = ...,
             day_time_interval: builtins.int = ...,
-            array: global___Expression.Literal.Array | None = ...,
-            struct: global___Expression.Literal.Struct | None = ...,
-            map: global___Expression.Literal.Map | None = ...,
+            typed_null: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
             nullable: builtins.bool = ...,
             type_variation_reference: builtins.int = ...,
         ) -> None: ...
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "array",
-                b"array",
                 "binary",
                 b"binary",
                 "boolean",
@@ -386,20 +295,18 @@ class Expression(google.protobuf.message.Message):
                 b"literal_type",
                 "long",
                 b"long",
-                "map",
-                b"map",
                 "null",
                 b"null",
                 "short",
                 b"short",
                 "string",
                 b"string",
-                "struct",
-                b"struct",
                 "timestamp",
                 b"timestamp",
                 "timestamp_ntz",
                 b"timestamp_ntz",
+                "typed_null",
+                b"typed_null",
                 "year_month_interval",
                 b"year_month_interval",
             ],
@@ -407,8 +314,6 @@ class Expression(google.protobuf.message.Message):
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "array",
-                b"array",
                 "binary",
                 b"binary",
                 "boolean",
@@ -433,8 +338,6 @@ class Expression(google.protobuf.message.Message):
                 b"literal_type",
                 "long",
                 b"long",
-                "map",
-                b"map",
                 "null",
                 b"null",
                 "nullable",
@@ -443,14 +346,14 @@ class Expression(google.protobuf.message.Message):
                 b"short",
                 "string",
                 b"string",
-                "struct",
-                b"struct",
                 "timestamp",
                 b"timestamp",
                 "timestamp_ntz",
                 b"timestamp_ntz",
                 "type_variation_reference",
                 b"type_variation_reference",
+                "typed_null",
+                b"typed_null",
                 "year_month_interval",
                 b"year_month_interval",
             ],
@@ -475,9 +378,7 @@ class Expression(google.protobuf.message.Message):
             "calendar_interval",
             "year_month_interval",
             "day_time_interval",
-            "array",
-            "struct",
-            "map",
+            "typed_null",
         ] | None: ...
 
     class UnresolvedAttribute(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 8b70b4d9a44..b7645bc4b71 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -14,9 +14,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import decimal
+import datetime
+
 from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-from pyspark.sql.types import StringType
+from pyspark.sql.connect.column import (
+    LiteralExpression,
+    JVM_BYTE_MIN,
+    JVM_BYTE_MAX,
+    JVM_SHORT_MIN,
+    JVM_SHORT_MAX,
+    JVM_INT_MIN,
+    JVM_INT_MAX,
+    JVM_LONG_MIN,
+    JVM_LONG_MAX,
+)
+
 from pyspark.sql.types import (
+    StructField,
+    StructType,
+    ArrayType,
+    MapType,
+    NullType,
+    DateType,
+    TimestampType,
+    TimestampNTZType,
     ByteType,
     ShortType,
     IntegerType,
@@ -92,13 +115,108 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         res = pd.DataFrame(data={"id": [0, 30, 60, 90]})
         self.assert_(pdf.equals(res), f"{pdf.to_string()} != {res.to_string()}")
 
+    def test_literal_with_acceptable_type(self):
+        for value, dataType in [
+            (b"binary\0\0asas", BinaryType()),
+            (True, BooleanType()),
+            (False, BooleanType()),
+            (0, ByteType()),
+            (JVM_BYTE_MIN, ByteType()),
+            (JVM_BYTE_MAX, ByteType()),
+            (0, ShortType()),
+            (JVM_SHORT_MIN, ShortType()),
+            (JVM_SHORT_MAX, ShortType()),
+            (0, IntegerType()),
+            (JVM_INT_MIN, IntegerType()),
+            (JVM_INT_MAX, IntegerType()),
+            (0, LongType()),
+            (JVM_LONG_MIN, LongType()),
+            (JVM_LONG_MAX, LongType()),
+            (0.0, FloatType()),
+            (1.234567, FloatType()),
+            (float("nan"), FloatType()),
+            (float("inf"), FloatType()),
+            (float("-inf"), FloatType()),
+            (0.0, DoubleType()),
+            (1.234567, DoubleType()),
+            (float("nan"), DoubleType()),
+            (float("inf"), DoubleType()),
+            (float("-inf"), DoubleType()),
+            (decimal.Decimal(0.0), DecimalType()),
+            (decimal.Decimal(1.234567), DecimalType()),
+            ("sss", StringType()),
+            (datetime.date(2022, 12, 13), DateType()),
+            (datetime.datetime.now(), DateType()),
+            (datetime.datetime.now(), TimestampType()),
+            (datetime.datetime.now(), TimestampNTZType()),
+            (datetime.timedelta(1, 2, 3), DayTimeIntervalType()),
+        ]:
+            lit = LiteralExpression(value=value, dataType=dataType)
+            self.assertEqual(dataType, lit._dataType)
+
+    def test_literal_with_unsupported_type(self):
+        for value, dataType in [
+            (b"binary\0\0asas", BooleanType()),
+            (True, StringType()),
+            (False, DoubleType()),
+            (JVM_BYTE_MIN - 1, ByteType()),
+            (JVM_BYTE_MAX + 1, ByteType()),
+            (JVM_SHORT_MIN - 1, ShortType()),
+            (JVM_SHORT_MAX + 1, ShortType()),
+            (JVM_INT_MIN - 1, IntegerType()),
+            (JVM_INT_MAX + 1, IntegerType()),
+            (JVM_LONG_MIN - 1, LongType()),
+            (JVM_LONG_MAX + 1, LongType()),
+            (0.1, DecimalType()),
+            (datetime.date(2022, 12, 13), TimestampType()),
+            (datetime.timedelta(1, 2, 3), DateType()),
+            ([1, 2, 3], ArrayType(IntegerType())),
+            ({1: 2}, MapType(IntegerType(), IntegerType())),
+            (
+                {"a": "xyz", "b": 1},
+                StructType([StructField("a", StringType()), StructField("b", 
IntegerType())]),
+            ),
+        ]:
+            with self.assertRaises(AssertionError):
+                LiteralExpression(value=value, dataType=dataType)
+
+    def test_literal_null(self):
+        for dataType in [
+            NullType(),
+            BinaryType(),
+            BooleanType(),
+            ByteType(),
+            ShortType(),
+            IntegerType(),
+            LongType(),
+            FloatType(),
+            DoubleType(),
+            DecimalType(),
+            DateType(),
+            TimestampType(),
+            TimestampNTZType(),
+            DayTimeIntervalType(),
+        ]:
+            lit_null = LiteralExpression(value=None, dataType=dataType)
+            self.assertTrue(lit_null._value is None)
+            self.assertEqual(dataType, lit_null._dataType)
+
+        for value, dataType in [
+            ("123", NullType()),
+            (123, NullType()),
+            (None, ArrayType(IntegerType())),
+            (None, MapType(IntegerType(), IntegerType())),
+            (None, StructType([StructField("a", StringType())])),
+        ]:
+            with self.assertRaises(AssertionError):
+                LiteralExpression(value=value, dataType=dataType)
+
     def test_literal_integers(self):
         cdf = self.connect.range(0, 1)
         sdf = self.spark.range(0, 1)
 
         from pyspark.sql import functions as SF
         from pyspark.sql.connect import functions as CF
-        from pyspark.sql.connect.column import JVM_INT_MIN, JVM_INT_MAX, JVM_LONG_MIN, JVM_LONG_MAX
 
         cdf1 = cdf.select(
             CF.lit(0),
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index d74473e725f..f401e66de01 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -68,21 +68,50 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         val = {"this": "is", 12: [12, 32, 43]}
         map_lit = fun.lit(val)
         map_lit_p = map_lit.to_plan(None)
-        self.assertEqual(2, len(map_lit_p.literal.map.pairs))
-        self.assertEqual("this", map_lit_p.literal.map.pairs[0].key.string)
-        self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.integer)
+
+        self.assertEqual(map_lit_p.unresolved_function.function_name, "map")
+        self.assertEqual(map_lit_p.unresolved_function.arguments[0].literal.string, "this")
+        self.assertEqual(map_lit_p.unresolved_function.arguments[1].literal.string, "is")
+        self.assertEqual(map_lit_p.unresolved_function.arguments[2].literal.integer, 12)
+
+        self.assertEqual(
+            map_lit_p.unresolved_function.arguments[3].unresolved_function.function_name, "array"
+        )
+        self.assertEqual(
+            map_lit_p.unresolved_function.arguments[3]
+            .unresolved_function.arguments[0]
+            .literal.integer,
+            12,
+        )
+        self.assertEqual(
+            map_lit_p.unresolved_function.arguments[3]
+            .unresolved_function.arguments[1]
+            .literal.integer,
+            32,
+        )
+        self.assertEqual(
+            map_lit_p.unresolved_function.arguments[3]
+            .unresolved_function.arguments[2]
+            .literal.integer,
+            43,
+        )
 
         val = {"this": fun.lit("is"), 12: [12, 32, 43]}
         map_lit = fun.lit(val)
         map_lit_p = map_lit.to_plan(None)
-        self.assertEqual(2, len(map_lit_p.literal.map.pairs))
-        self.assertEqual("is", map_lit_p.literal.map.pairs[0].value.string)
+        self.assertEqual(map_lit_p.unresolved_function.function_name, "map")
+        self.assertEqual(len(map_lit_p.unresolved_function.arguments), 4)
+        self.assertEqual(map_lit_p.unresolved_function.arguments[0].literal.string, "this")
+        self.assertEqual(map_lit_p.unresolved_function.arguments[1].literal.string, "is")
+        self.assertEqual(map_lit_p.unresolved_function.arguments[2].literal.integer, 12)
+        self.assertEqual(
+            map_lit_p.unresolved_function.arguments[3].unresolved_function.function_name, "array"
+        )
 
     def test_uuid_literal(self):
         val = uuid.uuid4()
-        lit = fun.lit(val)
         with self.assertRaises(ValueError):
-            lit.to_plan(None)
+            fun.lit(val)
 
     def test_column_literals(self):
         df = self.connect.with_plan(p.Read("table"))
@@ -162,27 +191,34 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
 
         p0 = fun.lit(t0).to_plan(None)
         self.assertIsNotNone(p0)
-        self.assertTrue(p0.literal.HasField("struct"))
+        self.assertEqual(p0.unresolved_function.function_name, "struct")
 
         p1 = fun.lit(t1).to_plan(None)
         self.assertIsNotNone(p1)
-        self.assertTrue(p1.literal.HasField("struct"))
-        self.assertEqual(p1.literal.struct.fields[0].double, 1.0)
+        self.assertEqual(p1.unresolved_function.function_name, "struct")
+        self.assertEqual(p1.unresolved_function.arguments[0].literal.double, 1.0)
 
         p2 = fun.lit(t2).to_plan(None)
         self.assertIsNotNone(p2)
-        self.assertTrue(p2.literal.HasField("struct"))
-        self.assertEqual(p2.literal.struct.fields[0].integer, 1)
-        self.assertEqual(p2.literal.struct.fields[1].string, "xyz")
+        self.assertEqual(p2.unresolved_function.function_name, "struct")
+        self.assertEqual(p2.unresolved_function.arguments[0].literal.integer, 1)
+        self.assertEqual(p2.unresolved_function.arguments[1].literal.string, "xyz")
 
         p3 = fun.lit(t3).to_plan(None)
         self.assertIsNotNone(p3)
-        self.assertTrue(p3.literal.HasField("struct"))
-        self.assertEqual(p3.literal.struct.fields[0].integer, 1)
-        self.assertEqual(p3.literal.struct.fields[1].string, "abc")
-        self.assertEqual(p3.literal.struct.fields[2].struct.fields[0].double, 3.5)
-        self.assertEqual(p3.literal.struct.fields[2].struct.fields[1].boolean, True)
-        self.assertEqual(p3.literal.struct.fields[2].struct.fields[2].null, True)
+        self.assertEqual(p3.unresolved_function.function_name, "struct")
+        self.assertEqual(p3.unresolved_function.arguments[0].literal.integer, 1)
+        self.assertEqual(p3.unresolved_function.arguments[1].literal.string, "abc")
+        self.assertEqual(
+            p3.unresolved_function.arguments[2].unresolved_function.arguments[0].literal.double, 3.5
+        )
+        self.assertEqual(
+            p3.unresolved_function.arguments[2].unresolved_function.arguments[1].literal.boolean,
+            True,
+        )
+        self.assertEqual(
+            p3.unresolved_function.arguments[2].unresolved_function.arguments[2].literal.null, True
+        )
 
     def test_column_alias(self) -> None:
         # SPARK-40809: Support for Column Aliases

