This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 39fbf594fa73 [SPARK-52448][CONNECT] Add simplified Struct Expression.Literal
39fbf594fa73 is described below

commit 39fbf594fa73e7fab2c1c3744d667c34248bfe2e
Author: Yihong He <heyihong...@gmail.com>
AuthorDate: Tue Jul 22 08:02:14 2025 +0900

    [SPARK-52448][CONNECT] Add simplified Struct Expression.Literal

    ### What changes were proposed in this pull request?

    This PR adds a new `data_type_struct` field to the protobuf definition for
    struct literals in Spark Connect, addressing the ambiguity of the existing
    `struct_type` field. The changes include:

    1. **Protobuf Schema Update**: Added a new `data_type_struct` field of type
       `DataType.Struct` to the `Literal.Struct` message in `expressions.proto`,
       and marked the existing `struct_type` field as deprecated.
    2. **Enhanced Struct Conversion Logic**: Updated `LiteralValueProtoConverter.scala` to:
       - Use the new `data_type_struct` field, when available, for a precise struct type definition
       - Maintain backward compatibility by still supporting the deprecated `struct_type` field
       - Handle field metadata properly in struct conversions
       - Infer the data types of struct fields from the literal values where possible

    ### Why are the changes needed?

    The current `Expression.Literal.Struct` message is overcomplicated: its
    `struct_type` field duplicates type information that the element literals
    already carry. This makes the message bulky to send over the wire, and the
    duplicated type information can be ambiguous.

    ### Does this PR introduce _any_ user-facing change?

    No. This PR maintains backward compatibility with existing struct literal
    implementations. Existing code that uses the deprecated `struct_type` field
    continues to work without modification.

    ### How was this patch tested?

    `build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`

    ### Was this patch authored or co-authored using generative AI tooling?

    Generated-by: Cursor 1.2.4

    Closes #51561 from heyihong/SPARK-52448.
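    For illustration, a struct literal such as `(1, "a")` can now be encoded as
    follows — a minimal sketch using the generated builder API exercised by the
    tests in this patch; the field names `id` and `name` are hypothetical:

    ```scala
    import org.apache.spark.connect.proto

    // The elements carry the values; data_type_struct carries only the type
    // information that the element literals cannot convey on their own.
    val structLiteral = proto.Expression.Literal.Struct
      .newBuilder()
      .addElements(proto.Expression.Literal.newBuilder().setInteger(1).build())
      .addElements(proto.Expression.Literal.newBuilder().setString("a").build())
      .setDataTypeStruct(
        proto.DataType.Struct
          .newBuilder()
          // The integer type is inferable from the element literal, so the
          // field's data type is left unset.
          .addFields(
            proto.DataType.StructField.newBuilder().setName("id").setNullable(true).build())
          // String types carry collation information and are not inferred, so
          // the data type is set explicitly.
          .addFields(
            proto.DataType.StructField
              .newBuilder()
              .setName("name")
              .setDataType(
                proto.DataType.newBuilder().setString(proto.DataType.String.newBuilder()).build())
              .setNullable(true)
              .build())
          .build())
      .build()
    ```

    On the receiving side, `LiteralValueProtoConverter.toCatalystStruct` now
    returns both the decoded value and the resolved `DataType.Struct`, filling in
    any field types it inferred from the element literals.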
Authored-by: Yihong He <heyihong...@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../pyspark/sql/connect/proto/expressions_pb2.py  | 116 +++++-----
 .../pyspark/sql/connect/proto/expressions_pb2.pyi |  32 ++-
 .../connect/ColumnNodeToProtoConverterSuite.scala |  27 ++-
 .../main/protobuf/spark/connect/expressions.proto |  15 +-
 .../common/LiteralValueProtoConverter.scala       | 240 ++++++++++++++++-----
 .../query-tests/queries/function_typedLit.json    | 202 +++++++----------
 .../queries/function_typedLit.proto.bin           | Bin 9442 -> 9358 bytes
 .../planner/LiteralExpressionProtoConverter.scala |   5 +-
 .../LiteralExpressionProtoConverterSuite.scala    | 133 ++++++++++++
 9 files changed, 520 insertions(+), 250 deletions(-)

diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 351bc65d30bc..fe83e78ce657 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common

 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xfe\x35\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\xcc\x36\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolved [...]
 )

 _globals = globals()
@@ -53,8 +53,12 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals[
         "DESCRIPTOR"
     ]._serialized_options = b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
+    _globals["_EXPRESSION_LITERAL_STRUCT"].fields_by_name["struct_type"]._loaded_options = None
+    _globals["_EXPRESSION_LITERAL_STRUCT"].fields_by_name[
+        "struct_type"
+    ]._serialized_options = b"\030\001"
     _globals["_EXPRESSION"]._serialized_start = 133
-    _globals["_EXPRESSION"]._serialized_end = 7043
+    _globals["_EXPRESSION"]._serialized_end = 7121
     _globals["_EXPRESSION_WINDOW"]._serialized_start = 1986
     _globals["_EXPRESSION_WINDOW"]._serialized_end = 2769
     _globals["_EXPRESSION_WINDOW_WINDOWFRAME"]._serialized_start = 2276
@@ -74,7 +78,7 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_start = 3401
     _globals["_EXPRESSION_CAST_EVALMODE"]._serialized_end = 3499
     _globals["_EXPRESSION_LITERAL"]._serialized_start = 3518
-    _globals["_EXPRESSION_LITERAL"]._serialized_end = 5642
+    _globals["_EXPRESSION_LITERAL"]._serialized_end = 5720
     _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_start = 4514
     _globals["_EXPRESSION_LITERAL_DECIMAL"]._serialized_end = 4631
     _globals["_EXPRESSION_LITERAL_CALENDARINTERVAL"]._serialized_start = 4633
@@ -84,57 +88,57 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_EXPRESSION_LITERAL_MAP"]._serialized_start = 4867
     _globals["_EXPRESSION_LITERAL_MAP"]._serialized_end = 5094
     _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_start = 5097
-    _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5226
-    _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_start = 5229
-    _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_end = 5549
-    _globals["_EXPRESSION_LITERAL_TIME"]._serialized_start = 5551
-    _globals["_EXPRESSION_LITERAL_TIME"]._serialized_end = 5626
-    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5645
-    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 5831
-    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 5834
-    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 6092
-    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 6094
-    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 6144
-    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 6146
-    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 6270
-    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 6272
-    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 6358
-    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 6361
-    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 6493
-    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 6496
-    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6683
-    _globals["_EXPRESSION_ALIAS"]._serialized_start = 6685
-    _globals["_EXPRESSION_ALIAS"]._serialized_end = 6805
-    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 6808
-    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 6966
-    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start = 6968
-    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end = 7030
-    _globals["_EXPRESSIONCOMMON"]._serialized_start = 7045
-    _globals["_EXPRESSIONCOMMON"]._serialized_end = 7110
-    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 7113
-    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 7510
-    _globals["_PYTHONUDF"]._serialized_start = 7513
-    _globals["_PYTHONUDF"]._serialized_end = 7717
-    _globals["_SCALARSCALAUDF"]._serialized_start = 7720
-    _globals["_SCALARSCALAUDF"]._serialized_end = 7934
-    _globals["_JAVAUDF"]._serialized_start = 7937
-    _globals["_JAVAUDF"]._serialized_end = 8086
-    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 8088
-    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 8187
-    _globals["_CALLFUNCTION"]._serialized_start = 8189
-    _globals["_CALLFUNCTION"]._serialized_end = 8297
-    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 8299
-    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 8391
-    _globals["_MERGEACTION"]._serialized_start = 8394
-    _globals["_MERGEACTION"]._serialized_end = 8906
-    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 8616
-    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8722
-    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8725
-    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8892
-    _globals["_SUBQUERYEXPRESSION"]._serialized_start = 8909
-    _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9618
-    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9215
-    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9449
-    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9452
-    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9596
+    _globals["_EXPRESSION_LITERAL_STRUCT"]._serialized_end = 5304
+    _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_start = 5307
+    _globals["_EXPRESSION_LITERAL_SPECIALIZEDARRAY"]._serialized_end = 5627
+    _globals["_EXPRESSION_LITERAL_TIME"]._serialized_start = 5629
+    _globals["_EXPRESSION_LITERAL_TIME"]._serialized_end = 5704
+    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_start = 5723
+    _globals["_EXPRESSION_UNRESOLVEDATTRIBUTE"]._serialized_end = 5909
+    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_start = 5912
+    _globals["_EXPRESSION_UNRESOLVEDFUNCTION"]._serialized_end = 6170
+    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_start = 6172
+    _globals["_EXPRESSION_EXPRESSIONSTRING"]._serialized_end = 6222
+    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_start = 6224
+    _globals["_EXPRESSION_UNRESOLVEDSTAR"]._serialized_end = 6348
+    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_start = 6350
+    _globals["_EXPRESSION_UNRESOLVEDREGEX"]._serialized_end = 6436
+    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_start = 6439
+    _globals["_EXPRESSION_UNRESOLVEDEXTRACTVALUE"]._serialized_end = 6571
+    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_start = 6574
+    _globals["_EXPRESSION_UPDATEFIELDS"]._serialized_end = 6761
+    _globals["_EXPRESSION_ALIAS"]._serialized_start = 6763
+    _globals["_EXPRESSION_ALIAS"]._serialized_end = 6883
+    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_start = 6886
+    _globals["_EXPRESSION_LAMBDAFUNCTION"]._serialized_end = 7044
+    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_start = 7046
+    _globals["_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE"]._serialized_end = 7108
+    _globals["_EXPRESSIONCOMMON"]._serialized_start = 7123
+    _globals["_EXPRESSIONCOMMON"]._serialized_end = 7188
+    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_start = 7191
+    _globals["_COMMONINLINEUSERDEFINEDFUNCTION"]._serialized_end = 7588
+    _globals["_PYTHONUDF"]._serialized_start = 7591
+    _globals["_PYTHONUDF"]._serialized_end = 7795
+    _globals["_SCALARSCALAUDF"]._serialized_start = 7798
+    _globals["_SCALARSCALAUDF"]._serialized_end = 8012
+    _globals["_JAVAUDF"]._serialized_start = 8015
+    _globals["_JAVAUDF"]._serialized_end = 8164
+    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_start = 8166
+    _globals["_TYPEDAGGREGATEEXPRESSION"]._serialized_end = 8265
+    _globals["_CALLFUNCTION"]._serialized_start = 8267
+    _globals["_CALLFUNCTION"]._serialized_end = 8375
+    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_start = 8377
+    _globals["_NAMEDARGUMENTEXPRESSION"]._serialized_end = 8469
+    _globals["_MERGEACTION"]._serialized_start = 8472
+    _globals["_MERGEACTION"]._serialized_end = 8984
+    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_start = 8694
+    _globals["_MERGEACTION_ASSIGNMENT"]._serialized_end = 8800
+    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_start = 8803
+    _globals["_MERGEACTION_ACTIONTYPE"]._serialized_end = 8970
+    _globals["_SUBQUERYEXPRESSION"]._serialized_start = 8987
+    _globals["_SUBQUERYEXPRESSION"]._serialized_end = 9696
+    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_start = 9293
+    _globals["_SUBQUERYEXPRESSION_TABLEARGOPTIONS"]._serialized_end = 9527
+    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_start = 9530
+    _globals["_SUBQUERYEXPRESSION_SUBQUERYTYPE"]._serialized_end = 9674
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 9376f5aac5c5..ad347fd4bd15 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -554,27 +554,51 @@ class Expression(google.protobuf.message.Message):

             STRUCT_TYPE_FIELD_NUMBER: builtins.int
             ELEMENTS_FIELD_NUMBER: builtins.int
+            DATA_TYPE_STRUCT_FIELD_NUMBER: builtins.int
             @property
-            def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Deprecated) The type of the struct.
+
+                This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
+                is ambiguous. This field should only be set if the data_type_struct field is not set.
+                Use data_type_struct field instead.
+                """
             @property
             def elements(
                 self,
             ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
                 global___Expression.Literal
-            ]: ...
+            ]:
+                """(Required) The literal values that make up the struct elements."""
+            @property
+            def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct:
+                """The type of the struct.
+
+                Whether data_type_struct.fields.data_type should be set depends on
+                whether each field's type can be inferred from the elements field.
+                """
            def __init__(
                 self,
                 *,
                 struct_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
                 elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
+                data_type_struct: pyspark.sql.connect.proto.types_pb2.DataType.Struct | None = ...,
             ) -> None: ...
             def HasField(
-                self, field_name: typing_extensions.Literal["struct_type", b"struct_type"]
+                self,
+                field_name: typing_extensions.Literal[
+                    "data_type_struct", b"data_type_struct", "struct_type", b"struct_type"
+                ],
             ) -> builtins.bool: ...
             def ClearField(
                 self,
                 field_name: typing_extensions.Literal[
-                    "elements", b"elements", "struct_type", b"struct_type"
+                    "data_type_struct",
+                    b"data_type_struct",
+                    "elements",
+                    b"elements",
+                    "struct_type",
+                    b"struct_type",
                 ],
             ) -> None: ...
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala
index 02f0c35c44a8..90da125b49ff 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala
@@ -79,15 +79,24 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite {
       Literal((12.0, "north", 60.0, "west"), Option(dataType)),
       expr { b =>
         val builder = b.getLiteralBuilder.getStructBuilder
-        builder.getStructTypeBuilder.getStructBuilder
-          .addFields(structField("_1", ProtoDataTypes.DoubleType))
-          .addFields(structField("_2", stringTypeWithCollation))
-          .addFields(structField("_3", ProtoDataTypes.DoubleType))
-          .addFields(structField("_4", stringTypeWithCollation))
-        builder.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0))
-        builder.addElements(proto.Expression.Literal.newBuilder().setString("north"))
-        builder.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0))
-        builder.addElements(proto.Expression.Literal.newBuilder().setString("west"))
+        builder
+          .addElements(proto.Expression.Literal.newBuilder().setDouble(12.0).build())
+        builder
+          .addElements(proto.Expression.Literal.newBuilder().setString("north").build())
+        builder
+          .addElements(proto.Expression.Literal.newBuilder().setDouble(60.0).build())
+        builder
+          .addElements(proto.Expression.Literal.newBuilder().setString("west").build())
+        builder.setDataTypeStruct(
+          proto.DataType.Struct
+            .newBuilder()
+            .addFields(
+              proto.DataType.StructField.newBuilder().setName("_1").setNullable(true).build())
+            .addFields(structField("_2", stringTypeWithCollation))
+            .addFields(
+              proto.DataType.StructField.newBuilder().setName("_3").setNullable(true).build())
+            .addFields(structField("_4", stringTypeWithCollation))
+            .build())
       })
   }
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 78fa0041e192..3ae6cb8dba9b 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -227,8 +227,21 @@ message Expression {
     }

     message Struct {
-      DataType struct_type = 1;
+      // (Deprecated) The type of the struct.
+      //
+      // This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
+      // is ambiguous. This field should only be set if the data_type_struct field is not set.
+      // Use data_type_struct field instead.
+      DataType struct_type = 1 [deprecated = true];
+
+      // (Required) The literal values that make up the struct elements.
       repeated Literal elements = 2;
+
+      // The type of the struct.
+      //
+      // Whether data_type_struct.fields.data_type should be set depends on
+      // whether each field's type can be inferred from the elements field.
+      DataType.Struct data_type_struct = 3;
     }

     message SpecializedArray {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
index 31bbfe08c8f8..4567cc10c81c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala
@@ -149,17 +149,36 @@ object LiteralValueProtoConverter {
     }

     def structBuilder(scalaValue: Any, structType: StructType) = {
-      val sb = builder.getStructBuilder.setStructType(toConnectProtoType(structType))
-      val dataTypes = structType.fields.map(_.dataType)
+      val sb = builder.getStructBuilder
+      val fields = structType.fields

       scalaValue match {
         case p: Product =>
           val iter = p.productIterator
+          val dataTypeStruct = proto.DataType.Struct.newBuilder()
           var idx = 0
           while (idx < structType.size) {
-            sb.addElements(toLiteralProto(iter.next(), dataTypes(idx)))
+            val field = fields(idx)
+            val literalProto = toLiteralProto(iter.next(), field.dataType)
+            sb.addElements(literalProto)
+
+            val fieldBuilder = dataTypeStruct
+              .addFieldsBuilder()
+              .setName(field.name)
+              .setNullable(field.nullable)
+
+            if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
+              fieldBuilder.setDataType(toConnectProtoType(field.dataType))
+            }
+
+            // Set metadata if available
+            if (field.metadata != Metadata.empty) {
+              fieldBuilder.setMetadata(field.metadata.json)
+            }
+
             idx += 1
           }
+          sb.setDataTypeStruct(dataTypeStruct.build())
         case other =>
           throw new IllegalArgumentException(s"literal $other not supported (yet).")
       }
@@ -300,54 +319,101 @@ object LiteralValueProtoConverter {
      case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
        toCatalystArray(literal.getArray)

+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        toCatalystStruct(literal.getStruct)._1
+
      case other =>
        throw new UnsupportedOperationException(
          s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
    }
  }

-  private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
-    if (dataType.hasShort) { v =>
-      v.getShort.toShort
-    } else if (dataType.hasInteger) { v =>
-      v.getInteger
-    } else if (dataType.hasLong) { v =>
-      v.getLong
-    } else if (dataType.hasDouble) { v =>
-      v.getDouble
-    } else if (dataType.hasByte) { v =>
-      v.getByte.toByte
-    } else if (dataType.hasFloat) { v =>
-      v.getFloat
-    } else if (dataType.hasBoolean) { v =>
-      v.getBoolean
-    } else if (dataType.hasString) { v =>
-      v.getString
-    } else if (dataType.hasBinary) { v =>
-      v.getBinary.toByteArray
-    } else if (dataType.hasDate) { v =>
-      v.getDate
-    } else if (dataType.hasTimestamp) { v =>
-      v.getTimestamp
-    } else if (dataType.hasTimestampNtz) { v =>
-      v.getTimestampNtz
-    } else if (dataType.hasDayTimeInterval) { v =>
-      v.getDayTimeInterval
-    } else if (dataType.hasYearMonthInterval) { v =>
-      v.getYearMonthInterval
-    } else if (dataType.hasDecimal) { v =>
-      Decimal(v.getDecimal.getValue)
-    } else if (dataType.hasCalendarInterval) { v =>
-      val interval = v.getCalendarInterval
-      new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
-    } else if (dataType.hasArray) { v =>
-      toCatalystArray(v.getArray)
-    } else if (dataType.hasMap) { v =>
-      toCatalystMap(v.getMap)
-    } else if (dataType.hasStruct) { v =>
-      toCatalystStruct(v.getStruct)
-    } else {
-      throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
+  private def getConverter(
+      dataType: proto.DataType,
+      inferDataType: Boolean = false): proto.Expression.Literal => Any = {
+    dataType.getKindCase match {
+      case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
+      case proto.DataType.KindCase.INTEGER => v => v.getInteger
+      case proto.DataType.KindCase.LONG => v => v.getLong
+      case proto.DataType.KindCase.DOUBLE => v => v.getDouble
+      case proto.DataType.KindCase.BYTE => v => v.getByte.toByte
+      case proto.DataType.KindCase.FLOAT => v => v.getFloat
+      case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean
+      case proto.DataType.KindCase.STRING => v => v.getString
+      case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray
+      case proto.DataType.KindCase.DATE => v => v.getDate
+      case proto.DataType.KindCase.TIMESTAMP => v => v.getTimestamp
+      case proto.DataType.KindCase.TIMESTAMP_NTZ => v => v.getTimestampNtz
+      case proto.DataType.KindCase.DAY_TIME_INTERVAL => v => v.getDayTimeInterval
+      case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => v => v.getYearMonthInterval
+      case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue)
+      case proto.DataType.KindCase.CALENDAR_INTERVAL =>
+        v =>
+          val interval = v.getCalendarInterval
+          new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
+      case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
+      case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
+      case proto.DataType.KindCase.STRUCT =>
+        if (inferDataType) { v =>
+          val (struct, structType) = toCatalystStruct(v.getStruct, None)
+          LiteralValueWithDataType(
+            struct,
+            proto.DataType.newBuilder.setStruct(structType).build())
+        } else { v =>
+          toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
+        }
+      case _ =>
+        throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
+    }
+  }
+
+  private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
+    if (literal.hasNull) {
+      return Some(literal.getNull)
+    }
+
+    val builder = proto.DataType.newBuilder()
+    literal.getLiteralTypeCase match {
+      case proto.Expression.Literal.LiteralTypeCase.BINARY =>
+        builder.setBinary(proto.DataType.Binary.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
+        builder.setBoolean(proto.DataType.Boolean.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.BYTE =>
+        builder.setByte(proto.DataType.Byte.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.SHORT =>
+        builder.setShort(proto.DataType.Short.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
+        builder.setInteger(proto.DataType.Integer.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.LONG =>
+        builder.setLong(proto.DataType.Long.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
+        builder.setFloat(proto.DataType.Float.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
+        builder.setDouble(proto.DataType.Double.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.DATE =>
+        builder.setDate(proto.DataType.Date.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
+        builder.setTimestamp(proto.DataType.Timestamp.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
+        builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
+        builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
+      case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
+        // The type of the fields will be inferred from the literals of the fields in the struct.
+        builder.setStruct(literal.getStruct.getStructType.getStruct)
+      case _ =>
+        // Not all data types support inferring the data type from the literal at the moment.
+        // e.g. the type of DayTimeInterval contains extra information like start_field and
+        // end_field and cannot be inferred from the literal.
+        return None
+    }
+    Some(builder.build())
+  }
+
+  private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
+    getInferredDataType(literal).getOrElse {
+      throw InvalidPlanInput(
+        s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
     }
   }

@@ -386,7 +452,9 @@ object LiteralValueProtoConverter {
     makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
   }

-  def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
+  def toCatalystStruct(
+      struct: proto.Expression.Literal.Struct,
+      structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
     def toTuple[A <: Object](data: Seq[A]): Product = {
       try {
         val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
@@ -397,16 +465,78 @@ object LiteralValueProtoConverter {
       }
     }

-    val elements = struct.getElementsList.asScala
-    val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
-    val structData = elements
-      .zip(dataTypes)
-      .map { case (element, dataType) =>
-        getConverter(dataType)(element)
+    if (struct.hasDataTypeStruct) {
+      // The new way to define and convert structs.
+      val (structData, structType) = if (structTypeOpt.isDefined) {
+        val structFields = structTypeOpt.get.getFieldsList.asScala
+        val structData =
+          struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
+            getConverter(structField.getDataType)(element)
+          }
+        (structData, structTypeOpt.get)
+      } else {
+        def protoStructField(
+            name: String,
+            dataType: proto.DataType,
+            nullable: Boolean,
+            metadata: Option[String]): proto.DataType.StructField = {
+          val builder = proto.DataType.StructField
+            .newBuilder()
+            .setName(name)
+            .setDataType(dataType)
+            .setNullable(nullable)
+          metadata.foreach(builder.setMetadata)
+          builder.build()
+        }
+
+        val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
+
+        val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
+          case (element, dataTypeField) =>
+            if (dataTypeField.hasDataType) {
+              (getConverter(dataTypeField.getDataType)(element), dataTypeField)
+            } else {
+              val outerDataType = getInferredDataTypeOrThrow(element)
+              val (value, dataType) =
+                getConverter(outerDataType, inferDataType = true)(element) match {
+                  case LiteralValueWithDataType(value, dataType) => (value, dataType)
+                  case value => (value, outerDataType)
+                }
+              (
+                value,
+                protoStructField(
+                  dataTypeField.getName,
+                  dataType,
+                  dataTypeField.getNullable,
+                  if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
+            }
+        }
+
+        val structType = proto.DataType.Struct
+          .newBuilder()
+          .addAllFields(structDataAndFields.map(_._2).asJava)
+          .build()
+
+        (structDataAndFields.map(_._1), structType)
       }
-      .asInstanceOf[scala.collection.Seq[Object]]
-      .toSeq
+      (toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
+    } else if (struct.hasStructType) {
+      // For backward compatibility, we still support the old way to define and convert structs.
+      val elements = struct.getElementsList.asScala
+      val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
+      val structData = elements
+        .zip(dataTypes)
+        .map { case (element, dataType) =>
+          getConverter(dataType)(element)
+        }
+        .asInstanceOf[scala.collection.Seq[Object]]
+        .toSeq

-    toTuple(structData)
+      (toTuple(structData), struct.getStructType.getStruct)
+    } else {
+      throw InvalidPlanInput("Data type information is missing in the struct literal.")
+    }
   }
+
+  private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
 }
diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
index fa1352557a58..bd9d6bb3c8bb 100644
--- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
+++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json
@@ -807,38 +807,28 @@
     }, {
       "literal": {
         "struct": {
-          "structType": {
-            "struct": {
-              "fields": [{
-                "name": "_1",
-                "dataType": {
-                  "string": {
-                    "collation": "UTF8_BINARY"
-                  }
-                },
-                "nullable": true
-              }, {
-                "name": "_2",
-                "dataType": {
-                  "integer": {
-                  }
-                }
-              }, {
-                "name": "_3",
-                "dataType": {
-                  "double": {
-                  }
-                }
-              }]
-            }
-          },
           "elements": [{
             "string": "a"
           }, {
             "integer": 2
           }, {
             "double": 1.0
-          }]
+          }],
+          "dataTypeStruct": {
+            "fields": [{
+              "name": "_1",
+              "dataType": {
+                "string": {
+                  "collation": "UTF8_BINARY"
+                }
+              },
+              "nullable": true
+            }, {
+              "name": "_2"
+            }, {
+              "name": "_3"
+            }]
+          }
         }
       },
      "common": {
@@ -1295,71 +1285,6 @@
     }, {
       "literal": {
         "struct": {
-          "structType": {
-            "struct": {
-              "fields": [{
-                "name": "_1",
-                "dataType": {
-                  "array": {
-                    "elementType": {
-                      "integer": {
-                      }
-                    }
-                  }
-                },
-                "nullable": true
-              }, {
-                "name": "_2",
-                "dataType": {
-                  "map": {
-                    "keyType": {
-                      "string": {
-                        "collation": "UTF8_BINARY"
-                      }
-                    },
-                    "valueType": {
-                      "integer": {
-                      }
-                    }
-                  }
-                },
-                "nullable": true
-              }, {
-                "name": "_3",
-                "dataType": {
-                  "struct": {
-                    "fields": [{
-                      "name": "_1",
-                      "dataType": {
-                        "string": {
-                          "collation": "UTF8_BINARY"
-                        }
-                      },
-                      "nullable": true
-                    }, {
-                      "name": "_2",
-                      "dataType": {
-                        "map": {
-                          "keyType": {
-                            "integer": {
-                            }
-                          },
-                          "valueType": {
-                            "string": {
-                              "collation": "UTF8_BINARY"
-                            }
-                          },
-                          "valueContainsNull": true
-                        }
-                      },
-                      "nullable": true
-                    }]
-                  }
-                },
-                "nullable": true
-              }]
-            }
-          },
           "elements": [{
             "array": {
               "elementType": {
                 "integer": {
                 }
               },
@@ -1398,36 +1323,6 @@
             }
           }, {
             "struct": {
-              "structType": {
-                "struct": {
-                  "fields": [{
-                    "name": "_1",
-                    "dataType": {
-                      "string": {
-                        "collation": "UTF8_BINARY"
-                      }
-                    },
-                    "nullable": true
-                  }, {
-                    "name": "_2",
-                    "dataType": {
-                      "map": {
-                        "keyType": {
-                          "integer": {
-                          }
-                        },
-                        "valueType": {
-                          "string": {
-                            "collation": "UTF8_BINARY"
-                          }
-                        },
-                        "valueContainsNull": true
-                      }
-                    },
-                    "nullable": true
-                  }]
-                }
-              },
               "elements": [{
                 "string": "a"
               }, {
@@ -1452,9 +1347,70 @@
                     "string": "b"
                   }]
                 }
-              }]
+              }],
+              "dataTypeStruct": {
+                "fields": [{
+                  "name": "_1",
+                  "dataType": {
+                    "string": {
+                      "collation": "UTF8_BINARY"
+                    }
+                  },
+                  "nullable": true
+                }, {
+                  "name": "_2",
+                  "dataType": {
+                    "map": {
+                      "keyType": {
+                        "integer": {
+                        }
+                      },
+                      "valueType": {
+                        "string": {
+                          "collation": "UTF8_BINARY"
+                        }
+                      },
+                      "valueContainsNull": true
+                    }
+                  },
+                  "nullable": true
+                }]
+              }
             }
-          }]
+          }],
+          "dataTypeStruct": {
+            "fields": [{
+              "name": "_1",
+              "dataType": {
+                "array": {
+                  "elementType": {
+                    "integer": {
+                    }
+                  }
+                }
+              },
+              "nullable": true
+            }, {
"name": "_2", + "dataType": { + "map": { + "keyType": { + "string": { + "collation": "UTF8_BINARY" + } + }, + "valueType": { + "integer": { + } + } + } + }, + "nullable": true + }, { + "name": "_3", + "nullable": true + }] + } } }, "common": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index dca6c588cb26..da3a4a946d21 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala index ab7b56a9b74c..10f046a57da9 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala @@ -117,8 +117,9 @@ object LiteralExpressionProtoConverter { DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType))) case proto.Expression.Literal.LiteralTypeCase.STRUCT => - val dataType = DataTypeProtoConverter.toCatalystType(lit.getStruct.getStructType) - val structData = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct) + val (structData, structType) = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct) + val dataType = DataTypeProtoConverter.toCatalystType( + proto.DataType.newBuilder.setStruct(structType).build()) val convert = CatalystTypeConverters.createToCatalystConverter(dataType) expressions.Literal(convert(structData), dataType) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index 79ef8decb310..559984e47cf8 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.connect.planner import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite +import org.apache.spark.connect.proto import org.apache.spark.sql.connect.common.LiteralValueProtoConverter +import org.apache.spark.sql.types._ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:ignore funsuite @@ -30,4 +32,135 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i LiteralValueProtoConverter.toCatalystValue(LiteralValueProtoConverter.toLiteralProto(v))) } } + + Seq( + ( + (1, "string", true), + StructType( + Seq( + StructField("a", IntegerType), + StructField("b", StringType), + StructField("c", BooleanType)))), + ( + Array((1, "string", true), (2, "string", false), (3, "string", true)), + ArrayType( + StructType( + Seq( + StructField("a", IntegerType), + StructField("b", StringType), + StructField("c", BooleanType))))), + ( + (1, (2, 3)), + StructType( + Seq( + StructField("a", IntegerType), + StructField( + "b", + StructType( + Seq(StructField("c", IntegerType), StructField("d", IntegerType)))))))).zipWithIndex + .foreach { case ((v, t), idx) => + 
test(s"complex proto value and catalyst value conversion #$idx") { + assertResult(v)( + LiteralValueProtoConverter.toCatalystValue( + LiteralValueProtoConverter.toLiteralProto(v, t))) + } + } + + test("backward compatibility for struct literal proto") { + // Test the old way of defining structs with structType field and elements + val structTypeProto = proto.DataType.Struct + .newBuilder() + .addFields( + proto.DataType.StructField + .newBuilder() + .setName("a") + .setDataType(proto.DataType + .newBuilder() + .setInteger(proto.DataType.Integer.newBuilder()) + .build()) + .setNullable(true) + .build()) + .addFields( + proto.DataType.StructField + .newBuilder() + .setName("b") + .setDataType(proto.DataType + .newBuilder() + .setString(proto.DataType.String.newBuilder()) + .build()) + .setNullable(false) + .build()) + .build() + + val structProto = proto.Expression.Literal.Struct + .newBuilder() + .setStructType(proto.DataType.newBuilder().setStruct(structTypeProto).build()) + .addElements(LiteralValueProtoConverter.toLiteralProto(1)) + .addElements(LiteralValueProtoConverter.toLiteralProto("test")) + .build() + + val (result, resultType) = LiteralValueProtoConverter.toCatalystStruct(structProto) + + // Verify the result is a tuple with correct values + assert(result.isInstanceOf[Product]) + val product = result.asInstanceOf[Product] + assert(product.productArity == 2) + assert(product.productElement(0) == 1) + assert(product.productElement(1) == "test") + + // Verify the returned struct type matches the original + assert(resultType.getFieldsCount == 2) + assert(resultType.getFields(0).getName == "a") + assert(resultType.getFields(0).getDataType.hasInteger) + assert(resultType.getFields(0).getNullable) + assert(resultType.getFields(1).getName == "b") + assert(resultType.getFields(1).getDataType.hasString) + assert(!resultType.getFields(1).getNullable) + } + + test("data types of struct fields are not set for inferable types") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + (1, 2.0, true, (1, 2)), + StructType( + Seq( + StructField("a", IntegerType), + StructField("b", DoubleType), + StructField("c", BooleanType), + StructField( + "d", + StructType(Seq(StructField("e", IntegerType), StructField("f", IntegerType))))))) + assert(!literalProto.getStruct.getDataTypeStruct.getFieldsList.get(0).hasDataType) + assert(!literalProto.getStruct.getDataTypeStruct.getFieldsList.get(1).hasDataType) + assert(!literalProto.getStruct.getDataTypeStruct.getFieldsList.get(2).hasDataType) + assert(!literalProto.getStruct.getDataTypeStruct.getFieldsList.get(3).hasDataType) + } + + test("data types of struct fields are set for non-inferable types") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + ("string", Decimal(1)), + StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(10, 2))))) + assert(literalProto.getStruct.getDataTypeStruct.getFieldsList.get(0).hasDataType) + assert(literalProto.getStruct.getDataTypeStruct.getFieldsList.get(1).hasDataType) + } + + test("nullable and metadata fields are set for struct literal proto") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + ("string", Decimal(1)), + StructType(Seq( + StructField("a", StringType, nullable = true, Metadata.fromJson("""{"key": "value"}""")), + StructField("b", DecimalType(10, 2), nullable = false)))) + val structFields = literalProto.getStruct.getDataTypeStruct.getFieldsList + assert(structFields.get(0).getNullable) + assert(structFields.get(0).hasMetadata) + 
+    assert(structFields.get(0).getMetadata == """{"key":"value"}""")
+    assert(!structFields.get(1).getNullable)
+    assert(!structFields.get(1).hasMetadata)
+
+    val (_, structTypeProto) = LiteralValueProtoConverter.toCatalystStruct(literalProto.getStruct)
+    assert(structTypeProto.getFieldsList.get(0).getNullable)
+    assert(structTypeProto.getFieldsList.get(0).hasMetadata)
+    assert(structTypeProto.getFieldsList.get(0).getMetadata == """{"key":"value"}""")
+    assert(!structTypeProto.getFieldsList.get(1).getNullable)
+    assert(!structTypeProto.getFieldsList.get(1).hasMetadata)
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org