This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new cdc73ad36e5 [SPARK-41506][CONNECT][PYTHON] Refactor LiteralExpression to support DataType cdc73ad36e5 is described below commit cdc73ad36e53544add2cfb7ea66941202014303e Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Wed Dec 14 09:50:05 2022 +0900 [SPARK-41506][CONNECT][PYTHON] Refactor LiteralExpression to support DataType ### What changes were proposed in this pull request? 1, existing `LiteralExpression` is a mixture of `Literal`, `CreateArray`, `CreateStruct` and `CreateMap`, since we have added collection functions `array`, `struct` and `create_map`, the `CreateXXX` expressions can be replaced with `UnresolvedFunction`; 2, add field `dataType` in `LiteralExpression`, so we can specify the DataType if needed, a special case is the typed null; 3, it is up to the `lit` function to infer the DataType, not `LiteralExpression` itself; ### Why are the changes needed? Refactor LiteralExpression to support DataType ### Does this PR introduce _any_ user-facing change? No, `LiteralExpression` is an internal class and should not be exposed to end users ### How was this patch tested? added UT Closes #39047 from zhengruifeng/connect_lit_datatype. 
Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../main/protobuf/spark/connect/expressions.proto | 23 +-- .../planner/LiteralValueProtoConverter.scala | 31 +-- python/pyspark/sql/connect/column.py | 213 +++++++++++++++------ python/pyspark/sql/connect/functions.py | 13 +- .../pyspark/sql/connect/proto/expressions_pb2.py | 86 ++------- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 115 +---------- .../sql/tests/connect/test_connect_column.py | 122 +++++++++++- .../connect/test_connect_column_expressions.py | 74 +++++-- 8 files changed, 365 insertions(+), 312 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 6c0facbfeee..c906f15e0a6 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -77,9 +77,7 @@ message Expression { int32 year_month_interval = 20; int64 day_time_interval = 21; - Array array = 22; - Struct struct = 23; - Map map = 24; + DataType typed_null = 22; } // whether the literal type should be treated as a nullable type. 
Applies to @@ -107,25 +105,6 @@ message Expression { int32 days = 2; int64 microseconds = 3; } - - message Struct { - // A possibly heterogeneously typed list of literals - repeated Literal fields = 1; - } - - message Array { - // A homogeneously typed list of literals - repeated Literal values = 1; - } - - message Map { - repeated Pair pairs = 1; - - message Pair { - Literal key = 1; - Literal value = 2; - } - } } // An unresolved attribute that is not explicitly bound to a specific column, but the column diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala index 5a54ad9ac64..46f6db64b8c 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.connect.planner -import scala.collection.JavaConverters._ - import org.apache.spark.connect.proto -import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{CreateArray, CreateMap, CreateStruct} +import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -99,20 +96,6 @@ object LiteralValueProtoConverter { case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL => expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType()) - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - val literals = lit.getArray.getValuesList.asScala.toArray.map(toCatalystExpression) - CreateArray(literals) - - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - val literals = lit.getStruct.getFieldsList.asScala.toArray.map(toCatalystExpression) - CreateStruct(literals) 
- - case proto.Expression.Literal.LiteralTypeCase.MAP => - val literals = lit.getMap.getPairsList.asScala.toArray.flatMap { pair => - toCatalystExpression(pair.getKey) :: toCatalystExpression(pair.getValue) :: Nil - } - CreateMap(literals) - case _ => throw InvalidPlanInput( s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" + @@ -122,18 +105,6 @@ object LiteralValueProtoConverter { def toCatalystValue(lit: proto.Expression.Literal): Any = { lit.getLiteralTypeCase match { - case proto.Expression.Literal.LiteralTypeCase.ARRAY => - lit.getArray.getValuesList.asScala.toArray.map(toCatalystValue) - - case proto.Expression.Literal.LiteralTypeCase.STRUCT => - val literals = lit.getStruct.getFieldsList.asScala.map(toCatalystValue).toSeq - InternalRow(literals: _*) - - case proto.Expression.Literal.LiteralTypeCase.MAP => - lit.getMap.getPairsList.asScala.toArray.map { pair => - toCatalystValue(pair.getKey) -> toCatalystValue(pair.getValue) - }.toMap - case proto.Expression.Literal.LiteralTypeCase.STRING => lit.getString case _ => toCatalystExpression(lit).asInstanceOf[expressions.Literal].value diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 58d4e3dc41e..673c486ee1f 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -21,7 +21,24 @@ import json import decimal import datetime -from pyspark.sql.types import TimestampType, DayTimeIntervalType, DataType, DateType +from pyspark.sql.types import ( + DateType, + NullType, + BooleanType, + BinaryType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType, + StringType, + DataType, + TimestampType, + TimestampNTZType, + DayTimeIntervalType, +) import pyspark.sql.connect.proto as proto from pyspark.sql.connect.types import pyspark_types_to_proto_types @@ -31,7 +48,10 @@ if TYPE_CHECKING: from pyspark.sql.connect.client import SparkConnectClient import pyspark.sql.connect.proto as proto - 
+JVM_BYTE_MIN = -(1 << 7) +JVM_BYTE_MAX = (1 << 7) - 1 +JVM_SHORT_MIN = -(1 << 15) +JVM_SHORT_MAX = (1 << 15) - 1 JVM_INT_MIN = -(1 << 31) JVM_INT_MAX = (1 << 31) - 1 JVM_LONG_MIN = -(1 << 63) @@ -166,78 +186,147 @@ class LiteralExpression(Expression): The Python types are converted best effort into the relevant proto types. On the Spark Connect server side, the proto types are converted to the Catalyst equivalents.""" - def __init__(self, value: Any) -> None: + def __init__(self, value: Any, dataType: DataType) -> None: super().__init__() - self._value = value - def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": - """Converts the literal expression to the literal in proto. + assert isinstance( + dataType, + ( + NullType, + BinaryType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType, + StringType, + DateType, + TimestampType, + TimestampNTZType, + DayTimeIntervalType, + ), + ) - TODO(SPARK-40533) This method always assumes the largest type and can thus - create weird interpretations of the literal.""" + if isinstance(dataType, NullType): + assert value is None + + if value is not None: + if isinstance(dataType, BinaryType): + assert isinstance(value, (bytes, bytearray)) + elif isinstance(dataType, BooleanType): + assert isinstance(value, bool) + elif isinstance(dataType, ByteType): + assert isinstance(value, int) and JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX + elif isinstance(dataType, ShortType): + assert isinstance(value, int) and JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX + elif isinstance(dataType, IntegerType): + assert isinstance(value, int) and JVM_INT_MIN <= int(value) <= JVM_INT_MAX + elif isinstance(dataType, LongType): + assert isinstance(value, int) and JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX + elif isinstance(dataType, FloatType): + assert isinstance(value, float) + elif isinstance(dataType, DoubleType): + assert isinstance(value, float) + elif isinstance(dataType, 
DecimalType): + assert isinstance(value, decimal.Decimal) + elif isinstance(dataType, StringType): + assert isinstance(value, str) + elif isinstance(dataType, DateType): + assert isinstance(value, (datetime.date, datetime.datetime)) + if isinstance(value, datetime.date): + value = DateType().toInternal(value) + else: + value = DateType().toInternal(value.date()) + elif isinstance(dataType, TimestampType): + assert isinstance(value, datetime.datetime) + value = TimestampType().toInternal(value) + elif isinstance(dataType, TimestampNTZType): + assert isinstance(value, datetime.datetime) + value = TimestampNTZType().toInternal(value) + elif isinstance(dataType, DayTimeIntervalType): + assert isinstance(value, datetime.timedelta) + value = DayTimeIntervalType().toInternal(value) + assert value is not None + else: + raise ValueError(f"Unsupported Data Type {dataType}") - from pyspark.sql.connect.functions import lit + self._value = value + self._dataType = dataType + + @classmethod + def _infer_type(cls, value: Any) -> DataType: + if value is None: + return NullType() + elif isinstance(value, (bytes, bytearray)): + return BinaryType() + elif isinstance(value, bool): + return BooleanType() + elif isinstance(value, int): + if JVM_INT_MIN <= value <= JVM_INT_MAX: + return IntegerType() + elif JVM_LONG_MIN <= value <= JVM_LONG_MAX: + return LongType() + else: + raise ValueError(f"integer {value} out of bounds") + elif isinstance(value, float): + return DoubleType() + elif isinstance(value, str): + return StringType() + elif isinstance(value, decimal.Decimal): + return DecimalType() + elif isinstance(value, datetime.datetime): + return TimestampType() + elif isinstance(value, datetime.date): + return DateType() + elif isinstance(value, datetime.timedelta): + return DayTimeIntervalType() + else: + raise ValueError(f"Unsupported Data Type {type(value).__name__}") + + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": + """Converts the literal expression to 
the literal in proto.""" expr = proto.Expression() - if self._value is None: + + if isinstance(self._dataType, NullType): expr.literal.null = True - elif isinstance(self._value, (bytes, bytearray)): + elif self._value is None: + expr.typed_null.CopyFrom(pyspark_types_to_proto_types(self._dataType)) + elif isinstance(self._dataType, BinaryType): expr.literal.binary = bytes(self._value) - elif isinstance(self._value, bool): + elif isinstance(self._dataType, BooleanType): expr.literal.boolean = bool(self._value) - elif isinstance(self._value, int): - if JVM_INT_MIN <= self._value <= JVM_INT_MAX: - expr.literal.integer = int(self._value) - elif JVM_LONG_MIN <= self._value <= JVM_LONG_MAX: - expr.literal.long = int(self._value) - else: - raise ValueError(f"integer {self._value} out of bounds") - elif isinstance(self._value, float): + elif isinstance(self._dataType, ByteType): + expr.literal.byte = int(self._value) + elif isinstance(self._dataType, ShortType): + expr.literal.short = int(self._value) + elif isinstance(self._dataType, IntegerType): + expr.literal.integer = int(self._value) + elif isinstance(self._dataType, LongType): + expr.literal.long = int(self._value) + elif isinstance(self._dataType, FloatType): + expr.literal.float = float(self._value) + elif isinstance(self._dataType, DoubleType): expr.literal.double = float(self._value) - elif isinstance(self._value, str): - expr.literal.string = str(self._value) - elif isinstance(self._value, decimal.Decimal): + elif isinstance(self._dataType, DecimalType): expr.literal.decimal.value = str(self._value) - expr.literal.decimal.precision = int(decimal.getcontext().prec) - elif isinstance(self._value, datetime.datetime): - expr.literal.timestamp = TimestampType().toInternal(self._value) - elif isinstance(self._value, datetime.date): - expr.literal.date = DateType().toInternal(self._value) - elif isinstance(self._value, datetime.timedelta): - interval = DayTimeIntervalType().toInternal(self._value) - assert interval is 
not None - expr.literal.day_time_interval = int(interval) - elif isinstance(self._value, list): - expr.literal.array.SetInParent() - for item in list(self._value): - if isinstance(item, Column): - expr.literal.array.values.append(item.to_plan(session).literal) - else: - expr.literal.array.values.append(lit(item).to_plan(session).literal) - elif isinstance(self._value, tuple): - expr.literal.struct.SetInParent() - for item in list(self._value): - if isinstance(item, Column): - expr.literal.struct.fields.append(item.to_plan(session).literal) - else: - expr.literal.struct.fields.append(lit(item).to_plan(session).literal) - elif isinstance(self._value, dict): - expr.literal.map.SetInParent() - for key, value in dict(self._value).items(): - pair = proto.Expression.Literal.Map.Pair() - if isinstance(key, Column): - pair.key.CopyFrom(key.to_plan(session).literal) - else: - pair.key.CopyFrom(lit(key).to_plan(session).literal) - if isinstance(value, Column): - pair.value.CopyFrom(value.to_plan(session).literal) - else: - pair.value.CopyFrom(lit(value).to_plan(session).literal) - expr.literal.map.pairs.append(pair) - elif isinstance(self._value, Column): - expr.CopyFrom(self._value.to_plan(session)) + expr.literal.decimal.precision = self._dataType.precision + expr.literal.decimal.scale = self._dataType.scale + elif isinstance(self._dataType, StringType): + expr.literal.string = str(self._value) + elif isinstance(self._dataType, DateType): + expr.literal.date = int(self._value) + elif isinstance(self._dataType, TimestampType): + expr.literal.timestamp = int(self._value) + elif isinstance(self._dataType, TimestampNTZType): + expr.literal.timestamp_ntz = int(self._value) + elif isinstance(self._dataType, DayTimeIntervalType): + expr.literal.day_time_interval = int(self._value) else: - raise ValueError(f"Could not convert literal for type {type(self._value)}") + raise ValueError(f"Unsupported Data Type {self._dataType}") return expr diff --git 
a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index dccb6d6e0c7..a81de4f9ba6 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -92,8 +92,19 @@ column = col def lit(col: Any) -> Column: if isinstance(col, Column): return col + elif isinstance(col, list): + return array(*[lit(c) for c in col]) + elif isinstance(col, tuple): + return struct(*[lit(c) for c in col]) + elif isinstance(col, dict): + cols = [] + for k, v in col.items(): + cols.append(lit(k)) + cols.append(lit(v)) + return create_map(*cols) else: - return Column(LiteralExpression(col)) + dataType = LiteralExpression._infer_type(col) + return Column(LiteralExpression(col, dataType)) # def bitwiseNOT(col: "ColumnOrName") -> Column: diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 91c57a9ef22..a311efc063a 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xf4\x14\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...] 
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xa5\x11\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...] ) @@ -42,10 +42,6 @@ _EXPRESSION_CAST = _EXPRESSION.nested_types_by_name["Cast"] _EXPRESSION_LITERAL = _EXPRESSION.nested_types_by_name["Literal"] _EXPRESSION_LITERAL_DECIMAL = _EXPRESSION_LITERAL.nested_types_by_name["Decimal"] _EXPRESSION_LITERAL_CALENDARINTERVAL = _EXPRESSION_LITERAL.nested_types_by_name["CalendarInterval"] -_EXPRESSION_LITERAL_STRUCT = _EXPRESSION_LITERAL.nested_types_by_name["Struct"] -_EXPRESSION_LITERAL_ARRAY = _EXPRESSION_LITERAL.nested_types_by_name["Array"] -_EXPRESSION_LITERAL_MAP = _EXPRESSION_LITERAL.nested_types_by_name["Map"] -_EXPRESSION_LITERAL_MAP_PAIR = _EXPRESSION_LITERAL_MAP.nested_types_by_name["Pair"] _EXPRESSION_UNRESOLVEDATTRIBUTE = _EXPRESSION.nested_types_by_name["UnresolvedAttribute"] _EXPRESSION_UNRESOLVEDFUNCTION = _EXPRESSION.nested_types_by_name["UnresolvedFunction"] _EXPRESSION_EXPRESSIONSTRING = _EXPRESSION.nested_types_by_name["ExpressionString"] @@ -86,42 +82,6 @@ Expression = _reflection.GeneratedProtocolMessageType( # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.CalendarInterval) }, ), - "Struct": _reflection.GeneratedProtocolMessageType( - "Struct", - (_message.Message,), - { - "DESCRIPTOR": _EXPRESSION_LITERAL_STRUCT, - "__module__": "spark.connect.expressions_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Struct) - }, - ), - "Array": _reflection.GeneratedProtocolMessageType( - "Array", - (_message.Message,), - { - "DESCRIPTOR": _EXPRESSION_LITERAL_ARRAY, - 
"__module__": "spark.connect.expressions_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Array) - }, - ), - "Map": _reflection.GeneratedProtocolMessageType( - "Map", - (_message.Message,), - { - "Pair": _reflection.GeneratedProtocolMessageType( - "Pair", - (_message.Message,), - { - "DESCRIPTOR": _EXPRESSION_LITERAL_MAP_PAIR, - "__module__": "spark.connect.expressions_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map.Pair) - }, - ), - "DESCRIPTOR": _EXPRESSION_LITERAL_MAP, - "__module__": "spark.connect.expressions_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal.Map) - }, - ), "DESCRIPTOR": _EXPRESSION_LITERAL, "__module__": "spark.connect.expressions_pb2" # @@protoc_insertion_point(class_scope:spark.connect.Expression.Literal) @@ -182,10 +142,6 @@ _sym_db.RegisterMessage(Expression.Cast) _sym_db.RegisterMessage(Expression.Literal) _sym_db.RegisterMessage(Expression.Literal.Decimal) _sym_db.RegisterMessage(Expression.Literal.CalendarInterval) -_sym_db.RegisterMessage(Expression.Literal.Struct) -_sym_db.RegisterMessage(Expression.Literal.Array) -_sym_db.RegisterMessage(Expression.Literal.Map) -_sym_db.RegisterMessage(Expression.Literal.Map.Pair) _sym_db.RegisterMessage(Expression.UnresolvedAttribute) _sym_db.RegisterMessage(Expression.UnresolvedFunction) _sym_db.RegisterMessage(Expression.ExpressionString) @@ -197,31 +153,23 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 78 - _EXPRESSION._serialized_end = 2754 + _EXPRESSION._serialized_end = 2291 _EXPRESSION_CAST._serialized_start = 640 _EXPRESSION_CAST._serialized_end = 785 _EXPRESSION_LITERAL._serialized_start = 788 - _EXPRESSION_LITERAL._serialized_end = 2246 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1684 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1801 - 
_EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1803 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1901 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 1903 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 1970 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 1972 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 2038 - _EXPRESSION_LITERAL_MAP._serialized_start = 2041 - _EXPRESSION_LITERAL_MAP._serialized_end = 2230 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_start = 2114 - _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2230 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2248 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2318 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2321 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2525 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2527 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2577 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2579 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2619 - _EXPRESSION_ALIAS._serialized_start = 2621 - _EXPRESSION_ALIAS._serialized_end = 2741 + _EXPRESSION_LITERAL._serialized_end = 1783 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1550 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 1667 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 1669 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 1767 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 1785 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 1855 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 1858 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2062 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2064 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2114 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2116 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2156 + _EXPRESSION_ALIAS._serialized_start = 2158 + _EXPRESSION_ALIAS._serialized_end = 2278 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 2c486f62a9d..d710fcab52d 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -190,87 +190,6 @@ class Expression(google.protobuf.message.Message): ], ) -> None: ... - class Struct(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - FIELDS_FIELD_NUMBER: builtins.int - @property - def fields( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - global___Expression.Literal - ]: - """A possibly heterogeneously typed list of literals""" - def __init__( - self, - *, - fields: collections.abc.Iterable[global___Expression.Literal] | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["fields", b"fields"] - ) -> None: ... - - class Array(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - VALUES_FIELD_NUMBER: builtins.int - @property - def values( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - global___Expression.Literal - ]: - """A homogeneously typed list of literals""" - def __init__( - self, - *, - values: collections.abc.Iterable[global___Expression.Literal] | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["values", b"values"] - ) -> None: ... - - class Map(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - class Pair(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - @property - def key(self) -> global___Expression.Literal: ... - @property - def value(self) -> global___Expression.Literal: ... - def __init__( - self, - *, - key: global___Expression.Literal | None = ..., - value: global___Expression.Literal | None = ..., - ) -> None: ... 
- def HasField( - self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] - ) -> builtins.bool: ... - def ClearField( - self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] - ) -> None: ... - - PAIRS_FIELD_NUMBER: builtins.int - @property - def pairs( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - global___Expression.Literal.Map.Pair - ]: ... - def __init__( - self, - *, - pairs: collections.abc.Iterable[global___Expression.Literal.Map.Pair] | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["pairs", b"pairs"] - ) -> None: ... - NULL_FIELD_NUMBER: builtins.int BINARY_FIELD_NUMBER: builtins.int BOOLEAN_FIELD_NUMBER: builtins.int @@ -288,9 +207,7 @@ class Expression(google.protobuf.message.Message): CALENDAR_INTERVAL_FIELD_NUMBER: builtins.int YEAR_MONTH_INTERVAL_FIELD_NUMBER: builtins.int DAY_TIME_INTERVAL_FIELD_NUMBER: builtins.int - ARRAY_FIELD_NUMBER: builtins.int - STRUCT_FIELD_NUMBER: builtins.int - MAP_FIELD_NUMBER: builtins.int + TYPED_NULL_FIELD_NUMBER: builtins.int NULLABLE_FIELD_NUMBER: builtins.int TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int null: builtins.bool @@ -316,11 +233,7 @@ class Expression(google.protobuf.message.Message): year_month_interval: builtins.int day_time_interval: builtins.int @property - def array(self) -> global___Expression.Literal.Array: ... - @property - def struct(self) -> global___Expression.Literal.Struct: ... - @property - def map(self) -> global___Expression.Literal.Map: ... + def typed_null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... nullable: builtins.bool """whether the literal type should be treated as a nullable type. 
Applies to all members of union other than the Typed null (which should directly @@ -351,17 +264,13 @@ class Expression(google.protobuf.message.Message): calendar_interval: global___Expression.Literal.CalendarInterval | None = ..., year_month_interval: builtins.int = ..., day_time_interval: builtins.int = ..., - array: global___Expression.Literal.Array | None = ..., - struct: global___Expression.Literal.Struct | None = ..., - map: global___Expression.Literal.Map | None = ..., + typed_null: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., nullable: builtins.bool = ..., type_variation_reference: builtins.int = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "array", - b"array", "binary", b"binary", "boolean", @@ -386,20 +295,18 @@ class Expression(google.protobuf.message.Message): b"literal_type", "long", b"long", - "map", - b"map", "null", b"null", "short", b"short", "string", b"string", - "struct", - b"struct", "timestamp", b"timestamp", "timestamp_ntz", b"timestamp_ntz", + "typed_null", + b"typed_null", "year_month_interval", b"year_month_interval", ], @@ -407,8 +314,6 @@ class Expression(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ - "array", - b"array", "binary", b"binary", "boolean", @@ -433,8 +338,6 @@ class Expression(google.protobuf.message.Message): b"literal_type", "long", b"long", - "map", - b"map", "null", b"null", "nullable", @@ -443,14 +346,14 @@ class Expression(google.protobuf.message.Message): b"short", "string", b"string", - "struct", - b"struct", "timestamp", b"timestamp", "timestamp_ntz", b"timestamp_ntz", "type_variation_reference", b"type_variation_reference", + "typed_null", + b"typed_null", "year_month_interval", b"year_month_interval", ], @@ -475,9 +378,7 @@ class Expression(google.protobuf.message.Message): "calendar_interval", "year_month_interval", "day_time_interval", - "array", - "struct", - "map", + "typed_null", ] | None: ... 
class UnresolvedAttribute(google.protobuf.message.Message): diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 8b70b4d9a44..b7645bc4b71 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -14,9 +14,32 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import decimal +import datetime + from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase -from pyspark.sql.types import StringType +from pyspark.sql.connect.column import ( + LiteralExpression, + JVM_BYTE_MIN, + JVM_BYTE_MAX, + JVM_SHORT_MIN, + JVM_SHORT_MAX, + JVM_INT_MIN, + JVM_INT_MAX, + JVM_LONG_MIN, + JVM_LONG_MAX, +) + from pyspark.sql.types import ( + StructField, + StructType, + ArrayType, + MapType, + NullType, + DateType, + TimestampType, + TimestampNTZType, ByteType, ShortType, IntegerType, @@ -92,13 +115,108 @@ class SparkConnectTests(SparkConnectSQLTestCase): res = pd.DataFrame(data={"id": [0, 30, 60, 90]}) self.assert_(pdf.equals(res), f"{pdf.to_string()} != {res.to_string()}") + def test_literal_with_acceptable_type(self): + for value, dataType in [ + (b"binary\0\0asas", BinaryType()), + (True, BooleanType()), + (False, BooleanType()), + (0, ByteType()), + (JVM_BYTE_MIN, ByteType()), + (JVM_BYTE_MAX, ByteType()), + (0, ShortType()), + (JVM_SHORT_MIN, ShortType()), + (JVM_SHORT_MAX, ShortType()), + (0, IntegerType()), + (JVM_INT_MIN, IntegerType()), + (JVM_INT_MAX, IntegerType()), + (0, LongType()), + (JVM_LONG_MIN, LongType()), + (JVM_LONG_MAX, LongType()), + (0.0, FloatType()), + (1.234567, FloatType()), + (float("nan"), FloatType()), + (float("inf"), FloatType()), + (float("-inf"), FloatType()), + (0.0, DoubleType()), + (1.234567, DoubleType()), + (float("nan"), DoubleType()), + (float("inf"), DoubleType()), + (float("-inf"), DoubleType()), + 
(decimal.Decimal(0.0), DecimalType()), + (decimal.Decimal(1.234567), DecimalType()), + ("sss", StringType()), + (datetime.date(2022, 12, 13), DateType()), + (datetime.datetime.now(), DateType()), + (datetime.datetime.now(), TimestampType()), + (datetime.datetime.now(), TimestampNTZType()), + (datetime.timedelta(1, 2, 3), DayTimeIntervalType()), + ]: + lit = LiteralExpression(value=value, dataType=dataType) + self.assertEqual(dataType, lit._dataType) + + def test_literal_with_unsupported_type(self): + for value, dataType in [ + (b"binary\0\0asas", BooleanType()), + (True, StringType()), + (False, DoubleType()), + (JVM_BYTE_MIN - 1, ByteType()), + (JVM_BYTE_MAX + 1, ByteType()), + (JVM_SHORT_MIN - 1, ShortType()), + (JVM_SHORT_MAX + 1, ShortType()), + (JVM_INT_MIN - 1, IntegerType()), + (JVM_INT_MAX + 1, IntegerType()), + (JVM_LONG_MIN - 1, LongType()), + (JVM_LONG_MAX + 1, LongType()), + (0.1, DecimalType()), + (datetime.date(2022, 12, 13), TimestampType()), + (datetime.timedelta(1, 2, 3), DateType()), + ([1, 2, 3], ArrayType(IntegerType())), + ({1: 2}, MapType(IntegerType(), IntegerType())), + ( + {"a": "xyz", "b": 1}, + StructType([StructField("a", StringType()), StructField("b", IntegerType())]), + ), + ]: + with self.assertRaises(AssertionError): + LiteralExpression(value=value, dataType=dataType) + + def test_literal_null(self): + for dataType in [ + NullType(), + BinaryType(), + BooleanType(), + ByteType(), + ShortType(), + IntegerType(), + LongType(), + FloatType(), + DoubleType(), + DecimalType(), + DateType(), + TimestampType(), + TimestampNTZType(), + DayTimeIntervalType(), + ]: + lit_null = LiteralExpression(value=None, dataType=dataType) + self.assertTrue(lit_null._value is None) + self.assertEqual(dataType, lit_null._dataType) + + for value, dataType in [ + ("123", NullType()), + (123, NullType()), + (None, ArrayType(IntegerType())), + (None, MapType(IntegerType(), IntegerType())), + (None, StructType([StructField("a", StringType())])), + ]: + with 
self.assertRaises(AssertionError): + LiteralExpression(value=value, dataType=dataType) + def test_literal_integers(self): cdf = self.connect.range(0, 1) sdf = self.spark.range(0, 1) from pyspark.sql import functions as SF from pyspark.sql.connect import functions as CF - from pyspark.sql.connect.column import JVM_INT_MIN, JVM_INT_MAX, JVM_LONG_MIN, JVM_LONG_MAX cdf1 = cdf.select( CF.lit(0), diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index d74473e725f..f401e66de01 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -68,21 +68,50 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): val = {"this": "is", 12: [12, 32, 43]} map_lit = fun.lit(val) map_lit_p = map_lit.to_plan(None) - self.assertEqual(2, len(map_lit_p.literal.map.pairs)) - self.assertEqual("this", map_lit_p.literal.map.pairs[0].key.string) - self.assertEqual(12, map_lit_p.literal.map.pairs[1].key.integer) + + self.assertEqual(map_lit_p.unresolved_function.function_name, "map") + self.assertEqual(map_lit_p.unresolved_function.arguments[0].literal.string, "this") + self.assertEqual(map_lit_p.unresolved_function.arguments[1].literal.string, "is") + self.assertEqual(map_lit_p.unresolved_function.arguments[2].literal.integer, 12) + + self.assertEqual( + map_lit_p.unresolved_function.arguments[3].unresolved_function.function_name, "array" + ) + self.assertEqual( + map_lit_p.unresolved_function.arguments[3] + .unresolved_function.arguments[0] + .literal.integer, + 12, + ) + self.assertEqual( + map_lit_p.unresolved_function.arguments[3] + .unresolved_function.arguments[1] + .literal.integer, + 32, + ) + self.assertEqual( + map_lit_p.unresolved_function.arguments[3] + .unresolved_function.arguments[2] + .literal.integer, + 43, + ) val = {"this": fun.lit("is"), 12: [12, 32, 43]} map_lit = 
fun.lit(val) map_lit_p = map_lit.to_plan(None) - self.assertEqual(2, len(map_lit_p.literal.map.pairs)) - self.assertEqual("is", map_lit_p.literal.map.pairs[0].value.string) + self.assertEqual(map_lit_p.unresolved_function.function_name, "map") + self.assertEqual(len(map_lit_p.unresolved_function.arguments), 4) + self.assertEqual(map_lit_p.unresolved_function.arguments[0].literal.string, "this") + self.assertEqual(map_lit_p.unresolved_function.arguments[1].literal.string, "is") + self.assertEqual(map_lit_p.unresolved_function.arguments[2].literal.integer, 12) + self.assertEqual( + map_lit_p.unresolved_function.arguments[3].unresolved_function.function_name, "array" + ) def test_uuid_literal(self): val = uuid.uuid4() - lit = fun.lit(val) with self.assertRaises(ValueError): - lit.to_plan(None) + fun.lit(val) def test_column_literals(self): df = self.connect.with_plan(p.Read("table")) @@ -162,27 +191,34 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): p0 = fun.lit(t0).to_plan(None) self.assertIsNotNone(p0) - self.assertTrue(p0.literal.HasField("struct")) + self.assertEqual(p0.unresolved_function.function_name, "struct") p1 = fun.lit(t1).to_plan(None) self.assertIsNotNone(p1) - self.assertTrue(p1.literal.HasField("struct")) - self.assertEqual(p1.literal.struct.fields[0].double, 1.0) + self.assertEqual(p1.unresolved_function.function_name, "struct") + self.assertEqual(p1.unresolved_function.arguments[0].literal.double, 1.0) p2 = fun.lit(t2).to_plan(None) self.assertIsNotNone(p2) - self.assertTrue(p2.literal.HasField("struct")) - self.assertEqual(p2.literal.struct.fields[0].integer, 1) - self.assertEqual(p2.literal.struct.fields[1].string, "xyz") + self.assertEqual(p2.unresolved_function.function_name, "struct") + self.assertEqual(p2.unresolved_function.arguments[0].literal.integer, 1) + self.assertEqual(p2.unresolved_function.arguments[1].literal.string, "xyz") p3 = fun.lit(t3).to_plan(None) self.assertIsNotNone(p3) - 
self.assertTrue(p3.literal.HasField("struct")) - self.assertEqual(p3.literal.struct.fields[0].integer, 1) - self.assertEqual(p3.literal.struct.fields[1].string, "abc") - self.assertEqual(p3.literal.struct.fields[2].struct.fields[0].double, 3.5) - self.assertEqual(p3.literal.struct.fields[2].struct.fields[1].boolean, True) - self.assertEqual(p3.literal.struct.fields[2].struct.fields[2].null, True) + self.assertEqual(p3.unresolved_function.function_name, "struct") + self.assertEqual(p3.unresolved_function.arguments[0].literal.integer, 1) + self.assertEqual(p3.unresolved_function.arguments[1].literal.string, "abc") + self.assertEqual( + p3.unresolved_function.arguments[2].unresolved_function.arguments[0].literal.double, 3.5 + ) + self.assertEqual( + p3.unresolved_function.arguments[2].unresolved_function.arguments[1].literal.boolean, + True, + ) + self.assertEqual( + p3.unresolved_function.arguments[2].unresolved_function.arguments[2].literal.null, True + ) def test_column_alias(self) -> None: # SPARK-40809: Support for Column Aliases --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org