This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ef5b3eee339 [SPARK-41924][CONNECT][PYTHON] Make StructType support metadata and Implement `DataFrame.withMetadata`
ef5b3eee339 is described below

commit ef5b3eee339a37803e96826bf27bfce9f6c8ed46
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Jan 7 09:09:58 2023 +0800

    [SPARK-41924][CONNECT][PYTHON] Make StructType support metadata and Implement `DataFrame.withMetadata`

    ### What changes were proposed in this pull request?
    Make `StructType` support metadata and implement `DataFrame.withMetadata`.

    ### Why are the changes needed?
    For API coverage.

    ### Does this PR introduce _any_ user-facing change?
    Yes: `DataFrame.withMetadata` is added to the Connect Python client, and schemas returned by Connect now carry column metadata.

    ### How was this patch tested?
    Added unit tests.

    Closes #39432 from zhengruifeng/connect_struct_metadata.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  15 ++
 .../src/main/protobuf/spark/connect/types.proto    |   2 +-
 .../connect/planner/DataTypeProtoConverter.scala   |  31 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   8 +
 python/pyspark/sql/connect/client.py               |  21 +-
 python/pyspark/sql/connect/dataframe.py            |  16 ++
 python/pyspark/sql/connect/plan.py                 |  22 +++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 214 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  43 +++++
 python/pyspark/sql/connect/proto/types_pb2.py      |  33 +---
 python/pyspark/sql/connect/proto/types_pb2.pyi     |  34 ++--
 python/pyspark/sql/connect/types.py                |  21 +-
 .../sql/tests/connect/test_connect_basic.py        |  23 ++-
 13 files changed, 308 insertions(+), 175 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index c0f22dd4576..45187f6adcf 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -60,6 +60,7 @@ message Relation {
     Unpivot unpivot = 25;
     ToSchema to_schema = 26;
     RepartitionByExpression repartition_by_expression = 27;
+    WithMetadata with_metadata = 28;
 
     // NA functions
     NAFill fill_na = 90;
@@ -694,6 +695,20 @@ message WithColumns {
   repeated Expression.Alias name_expr_list = 2;
 }
 
+
+// Update an existing column with metadata.
+message WithMetadata {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The column.
+  string column = 2;
+
+  // (Required) The JSON-formatted metadata.
+  string metadata = 3;
+}
+
+
 // Specify a hint over a relation. Hint should have a name and optional parameters.
 message Hint {
   // (Required) The input relation.
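[Editorial aside, not part of the patch.] The new relation is small, and its `metadata` field carries a JSON string rather than a structured map. A minimal sketch of building it by hand with the regenerated Python bindings added later in this patch; the empty input Relation is a stand-in for a real child plan:

    import json
    import pyspark.sql.connect.proto.relations_pb2 as relations_pb2

    rel = relations_pb2.Relation()
    rel.with_metadata.input.CopyFrom(relations_pb2.Relation())  # placeholder child plan
    rel.with_metadata.column = "age"
    rel.with_metadata.metadata = json.dumps({"max_age": 5})  # JSON on the wire
    assert rel.WhichOneof("rel_type") == "with_metadata"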
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/types.proto b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
index 785a191955f..7600c547714 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
@@ -152,7 +152,7 @@ message DataType {
     string name = 1;
     DataType data_type = 2;
     bool nullable = 3;
-    map<string, string> metadata = 4;
+    optional string metadata = 4;
   }
 
   message Struct {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index 0b8d79596c3..388ce9aebfc 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -95,13 +95,17 @@ object DataTypeProtoConverter {
   }
 
   private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
-    // TODO: support metadata
     val fields = t.getFieldsList.toSeq.map { protoField =>
+      val metadata = if (protoField.hasMetadata) {
+        Metadata.fromJson(protoField.getMetadata)
+      } else {
+        Metadata.empty
+      }
       StructField(
         name = protoField.getName,
         dataType = toCatalystType(protoField.getDataType),
         nullable = protoField.getNullable,
-        metadata = Metadata.empty)
+        metadata = metadata)
     }
     StructType.apply(fields)
   }
@@ -249,19 +253,28 @@ object DataTypeProtoConverter {
           .build()
 
       case StructType(fields: Array[StructField]) =>
-        // TODO: support metadata
         val protoFields = fields.toSeq.map {
           case StructField(
                 name: String,
                 dataType: DataType,
                 nullable: Boolean,
                 metadata: Metadata) =>
-            proto.DataType.StructField
-              .newBuilder()
-              .setName(name)
-              .setDataType(toConnectProtoType(dataType))
-              .setNullable(nullable)
-              .build()
+            if (metadata.equals(Metadata.empty)) {
+              proto.DataType.StructField
+                .newBuilder()
+                .setName(name)
+                .setDataType(toConnectProtoType(dataType))
+                .setNullable(nullable)
+                .build()
+            } else {
+              proto.DataType.StructField
+                .newBuilder()
+                .setName(name)
+                .setDataType(toConnectProtoType(dataType))
+                .setNullable(nullable)
+                .setMetadata(metadata.json)
+                .build()
+            }
         }
         proto.DataType
          .newBuilder()
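[Editorial aside, not part of the patch.] Note the asymmetry in the converter above: empty metadata is never serialized, and the receiving side maps an unset field back to `Metadata.empty` via `hasMetadata`. The same presence check is visible from Python through the regenerated `types_pb2` bindings:

    import json
    import pyspark.sql.connect.proto.types_pb2 as types_pb2

    field = types_pb2.DataType.StructField(name="age", nullable=True)
    assert not field.HasField("metadata")  # empty metadata is simply not sent

    field.metadata = json.dumps({"max_age": 5})
    assert field.HasField("metadata")      # non-empty metadata travels as JSON
    assert json.loads(field.metadata) == {"max_age": 5}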
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b4c882541e0..2980fc3a7e0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -106,6 +106,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
         transformRenameColumnsByNameToNameMap(rel.getRenameColumnsByNameToNameMap)
       case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns)
+      case proto.Relation.RelTypeCase.WITH_METADATA => transformWithMetadata(rel.getWithMetadata)
       case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
       case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
       case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
@@ -485,6 +486,13 @@ class SparkConnectPlanner(session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformWithMetadata(rel: proto.WithMetadata): LogicalPlan = {
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .withMetadata(rel.getColumn, Metadata.fromJson(rel.getMetadata))
+      .logicalPlan
+  }
+
   private def transformHint(rel: proto.Hint): LogicalPlan = {
     val params = rel.getParametersList.asScala.map(toCatalystValue).toSeq.map {
       case name: String => UnresolvedAttribute.quotedString(name)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 832b5648676..72839f8def9 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -19,6 +19,7 @@ import logging
 import os
 import urllib.parse
 import uuid
+import json
 from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, NoReturn, cast
 
 import pandas as pd
@@ -447,14 +448,20 @@ class SparkConnectClient(object):
 
         # Server side should populate the struct field which is the schema.
         assert proto_schema.HasField("struct")
 
-        fields = [
-            StructField(
-                f.name,
-                self._proto_schema_to_pyspark_schema(f.data_type),
-                f.nullable,
+        fields = []
+        for f in proto_schema.struct.fields:
+            if f.HasField("metadata"):
+                metadata = json.loads(f.metadata)
+            else:
+                metadata = None
+            fields.append(
+                StructField(
+                    f.name,
+                    self._proto_schema_to_pyspark_schema(f.data_type),
+                    f.nullable,
+                    metadata,
+                )
             )
-            for f in proto_schema.struct.fields
-        ]
         return StructType(fields)
 
     def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b1331439cf8..8aca9fbb968 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -34,6 +34,7 @@ import sys
 import random
 import pandas
 import datetime
+import json
 import warnings
 
 from collections.abc import Iterable
@@ -472,6 +473,21 @@ class DataFrame:
 
     sample.__doc__ = PySparkDataFrame.sample.__doc__
 
+    def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
+        if not isinstance(metadata, dict):
+            raise TypeError("metadata should be a dict")
+
+        return DataFrame.withPlan(
+            plan.WithMetadata(
+                child=self._plan,
+                column=columnName,
+                metadata=json.dumps(metadata),
+            ),
+            session=self._session,
+        )
+
+    withMetadata.__doc__ = PySparkDataFrame.withMetadata.__doc__
+
     def withColumnRenamed(self, existing: str, new: str) -> "DataFrame":
         return self.withColumnsRenamed({existing: new})
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 1973755be27..ae1b97a61c4 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -17,6 +17,7 @@
 from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 
 import functools
+import json
 import pyarrow as pa
 from inspect import signature, isclass
 
@@ -393,6 +394,27 @@ class WithColumns(LogicalPlan):
         return plan
 
 
+class WithMetadata(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], column: str, metadata: str) -> None:
+        super().__init__(child)
+
+        assert isinstance(column, str)
+        assert isinstance(metadata, str)
+        # validate json string
+        json.loads(metadata)
+
+        self._column = column
+        self._metadata = metadata
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.with_metadata.input.CopyFrom(self._child.plan(session))
+        plan.with_metadata.column = self._column
+        plan.with_metadata.metadata = self._metadata
+        return plan
+
+
 class Hint(LogicalPlan):
     """Logical plan object for a Hint operation."""
 
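[Editorial aside, not part of the patch.] Taken together, the client pieces compose as follows; this usage sketch assumes `spark` is an already-connected Spark Connect session (construction elided):

    df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])

    # The dict is json.dumps()-ed into plan.WithMetadata, shipped as a
    # WithMetadata relation, and applied server-side by the planner above.
    df2 = df.withMetadata("age", {"max_age": 5})
    assert df2.schema["age"].metadata == {"max_age": 5}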
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 9e230c3d239..fd3e209a78f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -85,6 +85,7 @@ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
     _RENAMECOLUMNSBYNAMETONAMEMAP.nested_types_by_name["RenameColumnsMapEntry"]
 )
 _WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"]
+_WITHMETADATA = DESCRIPTOR.message_types_by_name["WithMetadata"]
 _HINT = DESCRIPTOR.message_types_by_name["Hint"]
 _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
@@ -558,6 +559,17 @@ WithColumns = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(WithColumns)
 
+WithMetadata = _reflection.GeneratedProtocolMessageType(
+    "WithMetadata",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _WITHMETADATA,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.WithMetadata)
+    },
+)
+_sym_db.RegisterMessage(WithMetadata)
+
 Hint = _reflection.GeneratedProtocolMessageType(
     "Hint",
     (_message.Message,),
     {
@@ -611,103 +623,105 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2578
-    _UNKNOWN._serialized_start = 2580
-    _UNKNOWN._serialized_end = 2589
-    _RELATIONCOMMON._serialized_start = 2591
-    _RELATIONCOMMON._serialized_end = 2640
-    _SQL._serialized_start = 2642
-    _SQL._serialized_end = 2669
-    _READ._serialized_start = 2672
-    _READ._serialized_end = 3098
-    _READ_NAMEDTABLE._serialized_start = 2814
-    _READ_NAMEDTABLE._serialized_end = 2875
-    _READ_DATASOURCE._serialized_start = 2878
-    _READ_DATASOURCE._serialized_end = 3085
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3016
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3074
-    _PROJECT._serialized_start = 3100
-    _PROJECT._serialized_end = 3217
-    _FILTER._serialized_start = 3219
-    _FILTER._serialized_end = 3331
-    _JOIN._serialized_start = 3334
-    _JOIN._serialized_end = 3805
-    _JOIN_JOINTYPE._serialized_start = 3597
-    _JOIN_JOINTYPE._serialized_end = 3805
-    _SETOPERATION._serialized_start = 3808
-    _SETOPERATION._serialized_end = 4204
-    _SETOPERATION_SETOPTYPE._serialized_start = 4067
-    _SETOPERATION_SETOPTYPE._serialized_end = 4181
-    _LIMIT._serialized_start = 4206
-    _LIMIT._serialized_end = 4282
-    _OFFSET._serialized_start = 4284
-    _OFFSET._serialized_end = 4363
-    _TAIL._serialized_start = 4365
-    _TAIL._serialized_end = 4440
-    _AGGREGATE._serialized_start = 4443
-    _AGGREGATE._serialized_end = 5025
-    _AGGREGATE_PIVOT._serialized_start = 4782
-    _AGGREGATE_PIVOT._serialized_end = 4893
-    _AGGREGATE_GROUPTYPE._serialized_start = 4896
-    _AGGREGATE_GROUPTYPE._serialized_end = 5025
-    _SORT._serialized_start = 5028
-    _SORT._serialized_end = 5188
-    _DROP._serialized_start = 5190
-    _DROP._serialized_end = 5290
-    _DEDUPLICATE._serialized_start = 5293
-    _DEDUPLICATE._serialized_end = 5464
-    _LOCALRELATION._serialized_start = 5466
-    _LOCALRELATION._serialized_end = 5555
-    _SAMPLE._serialized_start = 5558
-    _SAMPLE._serialized_end = 5831
-    _RANGE._serialized_start = 5834
-    _RANGE._serialized_end = 5979
-    _SUBQUERYALIAS._serialized_start = 5981
-    _SUBQUERYALIAS._serialized_end = 6095
-    _REPARTITION._serialized_start = 6098
-    _REPARTITION._serialized_end = 6240
-    _SHOWSTRING._serialized_start = 6243
-    _SHOWSTRING._serialized_end = 6385
-    _STATSUMMARY._serialized_start = 6387
-    _STATSUMMARY._serialized_end = 6479
-    _STATDESCRIBE._serialized_start = 6481
-    _STATDESCRIBE._serialized_end = 6562
-    _STATCROSSTAB._serialized_start = 6564
-    _STATCROSSTAB._serialized_end = 6665
-    _STATCOV._serialized_start = 6667
-    _STATCOV._serialized_end = 6763
-    _STATCORR._serialized_start = 6766
-    _STATCORR._serialized_end = 6903
-    _STATAPPROXQUANTILE._serialized_start = 6906
-    _STATAPPROXQUANTILE._serialized_end = 7070
-    _STATFREQITEMS._serialized_start = 7072
-    _STATFREQITEMS._serialized_end = 7197
-    _STATSAMPLEBY._serialized_start = 7200
-    _STATSAMPLEBY._serialized_end = 7509
-    _STATSAMPLEBY_FRACTION._serialized_start = 7401
-    _STATSAMPLEBY_FRACTION._serialized_end = 7500
-    _NAFILL._serialized_start = 7512
-    _NAFILL._serialized_end = 7646
-    _NADROP._serialized_start = 7649
-    _NADROP._serialized_end = 7783
-    _NAREPLACE._serialized_start = 7786
-    _NAREPLACE._serialized_end = 8082
-    _NAREPLACE_REPLACEMENT._serialized_start = 7941
-    _NAREPLACE_REPLACEMENT._serialized_end = 8082
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8084
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8198
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8201
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8460
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8393
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8460
-    _WITHCOLUMNS._serialized_start = 8463
-    _WITHCOLUMNS._serialized_end = 8594
-    _HINT._serialized_start = 8597
-    _HINT._serialized_end = 8737
-    _UNPIVOT._serialized_start = 8740
-    _UNPIVOT._serialized_end = 8986
-    _TOSCHEMA._serialized_start = 8988
-    _TOSCHEMA._serialized_end = 9094
-    _REPARTITIONBYEXPRESSION._serialized_start = 9097
-    _REPARTITIONBYEXPRESSION._serialized_end = 9300
+    _RELATION._serialized_end = 2646
+    _UNKNOWN._serialized_start = 2648
+    _UNKNOWN._serialized_end = 2657
+    _RELATIONCOMMON._serialized_start = 2659
+    _RELATIONCOMMON._serialized_end = 2708
+    _SQL._serialized_start = 2710
+    _SQL._serialized_end = 2737
+    _READ._serialized_start = 2740
+    _READ._serialized_end = 3166
+    _READ_NAMEDTABLE._serialized_start = 2882
+    _READ_NAMEDTABLE._serialized_end = 2943
+    _READ_DATASOURCE._serialized_start = 2946
+    _READ_DATASOURCE._serialized_end = 3153
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3084
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3142
+    _PROJECT._serialized_start = 3168
+    _PROJECT._serialized_end = 3285
+    _FILTER._serialized_start = 3287
+    _FILTER._serialized_end = 3399
+    _JOIN._serialized_start = 3402
+    _JOIN._serialized_end = 3873
+    _JOIN_JOINTYPE._serialized_start = 3665
+    _JOIN_JOINTYPE._serialized_end = 3873
+    _SETOPERATION._serialized_start = 3876
+    _SETOPERATION._serialized_end = 4272
+    _SETOPERATION_SETOPTYPE._serialized_start = 4135
+    _SETOPERATION_SETOPTYPE._serialized_end = 4249
+    _LIMIT._serialized_start = 4274
+    _LIMIT._serialized_end = 4350
+    _OFFSET._serialized_start = 4352
+    _OFFSET._serialized_end = 4431
+    _TAIL._serialized_start = 4433
+    _TAIL._serialized_end = 4508
+    _AGGREGATE._serialized_start = 4511
+    _AGGREGATE._serialized_end = 5093
+    _AGGREGATE_PIVOT._serialized_start = 4850
+    _AGGREGATE_PIVOT._serialized_end = 4961
+    _AGGREGATE_GROUPTYPE._serialized_start = 4964
+    _AGGREGATE_GROUPTYPE._serialized_end = 5093
+    _SORT._serialized_start = 5096
+    _SORT._serialized_end = 5256
+    _DROP._serialized_start = 5258
+    _DROP._serialized_end = 5358
+    _DEDUPLICATE._serialized_start = 5361
+    _DEDUPLICATE._serialized_end = 5532
+    _LOCALRELATION._serialized_start = 5534
+    _LOCALRELATION._serialized_end = 5623
+    _SAMPLE._serialized_start = 5626
+    _SAMPLE._serialized_end = 5899
+    _RANGE._serialized_start = 5902
+    _RANGE._serialized_end = 6047
+    _SUBQUERYALIAS._serialized_start = 6049
+    _SUBQUERYALIAS._serialized_end = 6163
+    _REPARTITION._serialized_start = 6166
+    _REPARTITION._serialized_end = 6308
+    _SHOWSTRING._serialized_start = 6311
+    _SHOWSTRING._serialized_end = 6453
+    _STATSUMMARY._serialized_start = 6455
+    _STATSUMMARY._serialized_end = 6547
+    _STATDESCRIBE._serialized_start = 6549
+    _STATDESCRIBE._serialized_end = 6630
+    _STATCROSSTAB._serialized_start = 6632
+    _STATCROSSTAB._serialized_end = 6733
+    _STATCOV._serialized_start = 6735
+    _STATCOV._serialized_end = 6831
+    _STATCORR._serialized_start = 6834
+    _STATCORR._serialized_end = 6971
+    _STATAPPROXQUANTILE._serialized_start = 6974
+    _STATAPPROXQUANTILE._serialized_end = 7138
+    _STATFREQITEMS._serialized_start = 7140
+    _STATFREQITEMS._serialized_end = 7265
+    _STATSAMPLEBY._serialized_start = 7268
+    _STATSAMPLEBY._serialized_end = 7577
+    _STATSAMPLEBY_FRACTION._serialized_start = 7469
+    _STATSAMPLEBY_FRACTION._serialized_end = 7568
+    _NAFILL._serialized_start = 7580
+    _NAFILL._serialized_end = 7714
+    _NADROP._serialized_start = 7717
+    _NADROP._serialized_end = 7851
+    _NAREPLACE._serialized_start = 7854
+    _NAREPLACE._serialized_end = 8150
+    _NAREPLACE_REPLACEMENT._serialized_start = 8009
+    _NAREPLACE_REPLACEMENT._serialized_end = 8150
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8152
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8266
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8269
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8528
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8461
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8528
+    _WITHCOLUMNS._serialized_start = 8531
+    _WITHCOLUMNS._serialized_end = 8662
+    _WITHMETADATA._serialized_start = 8664
+    _WITHMETADATA._serialized_end = 8777
+    _HINT._serialized_start = 8780
+    _HINT._serialized_end = 8920
+    _UNPIVOT._serialized_start = 8923
+    _UNPIVOT._serialized_end = 9169
+    _TOSCHEMA._serialized_start = 9171
+    _TOSCHEMA._serialized_end = 9277
+    _REPARTITIONBYEXPRESSION._serialized_start = 9280
+    _REPARTITIONBYEXPRESSION._serialized_end = 9483
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 811f005d24b..11b11394c0c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -89,6 +89,7 @@ class Relation(google.protobuf.message.Message):
     UNPIVOT_FIELD_NUMBER: builtins.int
     TO_SCHEMA_FIELD_NUMBER: builtins.int
     REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
+    WITH_METADATA_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
    REPLACE_FIELD_NUMBER: builtins.int
@@ -158,6 +159,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def repartition_by_expression(self) -> global___RepartitionByExpression: ...
     @property
+    def with_metadata(self) -> global___WithMetadata: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -221,6 +224,7 @@ class Relation(google.protobuf.message.Message):
         unpivot: global___Unpivot | None = ...,
         to_schema: global___ToSchema | None = ...,
         repartition_by_expression: global___RepartitionByExpression | None = ...,
+        with_metadata: global___WithMetadata | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -323,6 +327,8 @@ class Relation(google.protobuf.message.Message):
             b"unpivot",
             "with_columns",
             b"with_columns",
+            "with_metadata",
+            b"with_metadata",
         ],
     ) -> builtins.bool: ...
     def ClearField(
@@ -412,6 +418,8 @@ class Relation(google.protobuf.message.Message):
             b"unpivot",
             "with_columns",
             b"with_columns",
+            "with_metadata",
+            b"with_metadata",
         ],
     ) -> None: ...
     def WhichOneof(
@@ -443,6 +451,7 @@ class Relation(google.protobuf.message.Message):
         "unpivot",
         "to_schema",
         "repartition_by_expression",
+        "with_metadata",
         "fill_na",
         "drop_na",
         "replace",
@@ -2356,6 +2365,40 @@ class WithColumns(google.protobuf.message.Message):
 
 global___WithColumns = WithColumns
 
+class WithMetadata(google.protobuf.message.Message):
+    """Update an existing column with metadata."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLUMN_FIELD_NUMBER: builtins.int
+    METADATA_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    column: builtins.str
+    """(Required) The column."""
+    metadata: builtins.str
+    """(Required) The JSON-formatted metadata."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        column: builtins.str = ...,
+        metadata: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "column", b"column", "input", b"input", "metadata", b"metadata"
+        ],
+    ) -> None: ...
+
+global___WithMetadata = WithMetadata
+
 class Hint(google.protobuf.message.Message):
     """Specify a hint over a relation.
    Hint should have a name and optional parameters."""
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index 2fcb56acb4d..4b2137d2901 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,7 +30,7 @@ _sym_db = _symbol_database.Default()
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xec\x1d\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01( [...]
+    b'\n\x19spark/connect/types.proto\x12\rspark.connect"\x8e\x1d\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01( [...]
 )
 
 
@@ -55,7 +55,6 @@ _DATATYPE_CHAR = _DATATYPE.nested_types_by_name["Char"]
 _DATATYPE_VARCHAR = _DATATYPE.nested_types_by_name["VarChar"]
 _DATATYPE_DECIMAL = _DATATYPE.nested_types_by_name["Decimal"]
 _DATATYPE_STRUCTFIELD = _DATATYPE.nested_types_by_name["StructField"]
-_DATATYPE_STRUCTFIELD_METADATAENTRY = _DATATYPE_STRUCTFIELD.nested_types_by_name["MetadataEntry"]
 _DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"]
 _DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"]
 _DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"]
@@ -238,15 +237,6 @@ DataType = _reflection.GeneratedProtocolMessageType(
         "StructField",
         (_message.Message,),
         {
-            "MetadataEntry": _reflection.GeneratedProtocolMessageType(
-                "MetadataEntry",
-                (_message.Message,),
-                {
-                    "DESCRIPTOR": _DATATYPE_STRUCTFIELD_METADATAENTRY,
-                    "__module__": "spark.connect.types_pb2"
-                    # @@protoc_insertion_point(class_scope:spark.connect.DataType.StructField.MetadataEntry)
-                },
-            ),
             "DESCRIPTOR": _DATATYPE_STRUCTFIELD,
             "__module__": "spark.connect.types_pb2"
             # @@protoc_insertion_point(class_scope:spark.connect.DataType.StructField)
@@ -305,7 +295,6 @@ _sym_db.RegisterMessage(DataType.Char)
 _sym_db.RegisterMessage(DataType.VarChar)
 _sym_db.RegisterMessage(DataType.Decimal)
 _sym_db.RegisterMessage(DataType.StructField)
-_sym_db.RegisterMessage(DataType.StructField.MetadataEntry)
 _sym_db.RegisterMessage(DataType.Struct)
 _sym_db.RegisterMessage(DataType.Array)
 _sym_db.RegisterMessage(DataType.Map)
@@ -314,10 +303,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._options = None
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_options = b"8\001"
     _DATATYPE._serialized_start = 45
-    _DATATYPE._serialized_end = 3865
+    _DATATYPE._serialized_end = 3771
     _DATATYPE_BOOLEAN._serialized_start = 1421
     _DATATYPE_BOOLEAN._serialized_end = 1488
     _DATATYPE_BYTE._serialized_start = 1490
@@ -357,13 +344,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _DATATYPE_DECIMAL._serialized_start = 2930
     _DATATYPE_DECIMAL._serialized_end = 3083
     _DATATYPE_STRUCTFIELD._serialized_start = 3086
-    _DATATYPE_STRUCTFIELD._serialized_end = 3341
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3282
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3341
-    _DATATYPE_STRUCT._serialized_start = 3343
-    _DATATYPE_STRUCT._serialized_end = 3470
-    _DATATYPE_ARRAY._serialized_start = 3473
-    _DATATYPE_ARRAY._serialized_end = 3635
-    _DATATYPE_MAP._serialized_start = 3638
-    _DATATYPE_MAP._serialized_end = 3857
+    _DATATYPE_STRUCTFIELD._serialized_end = 3247
+    _DATATYPE_STRUCT._serialized_start = 3249
+    _DATATYPE_STRUCT._serialized_end = 3376
+    _DATATYPE_ARRAY._serialized_start = 3379
+    _DATATYPE_ARRAY._serialized_end = 3541
+    _DATATYPE_MAP._serialized_start = 3544
+    _DATATYPE_MAP._serialized_end = 3763
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index 72736301f88..3db884af5f2 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -497,23 +497,6 @@ class DataType(google.protobuf.message.Message):
     class StructField(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        class MetadataEntry(google.protobuf.message.Message):
-            DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-            KEY_FIELD_NUMBER: builtins.int
-            VALUE_FIELD_NUMBER: builtins.int
-            key: builtins.str
-            value: builtins.str
-            def __init__(
-                self,
-                *,
-                key: builtins.str = ...,
-                value: builtins.str = ...,
-            ) -> None: ...
-            def ClearField(
-                self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
-            ) -> None: ...
-
         NAME_FIELD_NUMBER: builtins.int
         DATA_TYPE_FIELD_NUMBER: builtins.int
         NULLABLE_FIELD_NUMBER: builtins.int
@@ -522,24 +505,26 @@ class DataType(google.protobuf.message.Message):
         @property
         def data_type(self) -> global___DataType: ...
         nullable: builtins.bool
-        @property
-        def metadata(
-            self,
-        ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
+        metadata: builtins.str
         def __init__(
             self,
             *,
             name: builtins.str = ...,
            data_type: global___DataType | None = ...,
            nullable: builtins.bool = ...,
-            metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+            metadata: builtins.str | None = ...,
        ) -> None: ...
        def HasField(
-            self, field_name: typing_extensions.Literal["data_type", b"data_type"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_metadata", b"_metadata", "data_type", b"data_type", "metadata", b"metadata"
+            ],
        ) -> builtins.bool: ...
        def ClearField(
            self,
            field_name: typing_extensions.Literal[
+                "_metadata",
+                b"_metadata",
                "data_type",
                b"data_type",
                "metadata",
@@ -550,6 +535,9 @@ class DataType(google.protobuf.message.Message):
                b"nullable",
            ],
        ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
+        ) -> typing_extensions.Literal["metadata"] | None: ...
 
     class Struct(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
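[Editorial aside, not part of the patch.] One subtlety in the deserialization below: when the proto field is unset, the client passes metadata=None to StructField, and PySpark's StructField normalizes None to an empty dict, which is why the new test can assert that untouched columns have metadata == {}:

    from pyspark.sql.types import LongType, StructField

    field = StructField("age", LongType(), True, metadata=None)
    assert field.metadata == {}  # StructField stores `metadata or {}`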
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 1c6c3fa8c21..2f4abcec9b3 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 #
 
+import json
+
 from typing import Optional
 
 from pyspark.sql.types import (
@@ -94,6 +96,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
             struct_field.name = field.name
             struct_field.data_type.CopyFrom(pyspark_types_to_proto_types(field.dataType))
             struct_field.nullable = field.nullable
+            if field.metadata is not None and len(field.metadata) > 0:
+                struct_field.metadata = json.dumps(field.metadata)
             ret.struct.fields.append(struct_field)
     elif isinstance(data_type, MapType):
         ret.map.key_type.CopyFrom(pyspark_types_to_proto_types(data_type.keyType))
@@ -160,14 +164,17 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
             schema.array.contains_null,
         )
     elif schema.HasField("struct"):
-        fields = [
-            StructField(
-                f.name,
-                proto_schema_to_pyspark_data_type(f.data_type),
-                f.nullable,
+        fields = []
+        for f in schema.struct.fields:
+            if f.HasField("metadata"):
+                metadata = json.loads(f.metadata)
+            else:
+                metadata = None
+            fields.append(
+                StructField(
+                    f.name, proto_schema_to_pyspark_data_type(f.data_type), f.nullable, metadata
+                )
             )
-            for f in schema.struct.fields
-        ]
         return StructType(fields)
     elif schema.HasField("map"):
         return MapType(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 356eb17ee05..72e60712b98 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -671,11 +671,9 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
             AS tab(a, b, c, d, e, f, g, h)
             """
-        # compare the __repr__() to ignore the metadata for now
-        # the metadata is not supported in Connect for now
         self.assertEqual(
-            self.spark.sql(query).schema.__repr__(),
-            self.connect.sql(query).schema.__repr__(),
+            self.spark.sql(query).schema,
+            self.connect.sql(query).schema,
         )
 
     def test_to(self):
@@ -1971,6 +1969,23 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         ):
             cdf.groupBy("name").pivot("department").sum("salary", "department").show()
 
+    def test_with_metadata(self):
+        cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
+        self.assertEqual(cdf.schema["age"].metadata, {})
+        self.assertEqual(cdf.schema["name"].metadata, {})
+
+        cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5})
+        self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5})
+
+        cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]})
+        self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]})
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "metadata should be a dict",
+        ):
+            cdf.withMetadata(columnName="name", metadata=["magic"])
+
     def test_unsupported_functions(self):
         # SPARK-41225: Disable unsupported functions.
         df = self.connect.read.table(self.tbl_name)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org