This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ef5b3eee339 [SPARK-41924][CONNECT][PYTHON] Make StructType support metadata and Implement `DataFrame.withMetadata`
ef5b3eee339 is described below

commit ef5b3eee339a37803e96826bf27bfce9f6c8ed46
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Jan 7 09:09:58 2023 +0800

    [SPARK-41924][CONNECT][PYTHON] Make StructType support metadata and Implement `DataFrame.withMetadata`
    
    ### What changes were proposed in this pull request?
    Make `StructType` support metadata and Implement `DataFrame.withMetadata`
    
    ### Why are the changes needed?
    For API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `DataFrame.withMetadata` is added to the Spark Connect Python client, and Connect schemas now carry column metadata.
    
    ### How was this patch tested?
    Added a unit test (`test_with_metadata`).
    
    Closes #39432 from zhengruifeng/connect_struct_metadata.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  15 ++
 .../src/main/protobuf/spark/connect/types.proto    |   2 +-
 .../connect/planner/DataTypeProtoConverter.scala   |  31 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   8 +
 python/pyspark/sql/connect/client.py               |  21 +-
 python/pyspark/sql/connect/dataframe.py            |  16 ++
 python/pyspark/sql/connect/plan.py                 |  22 +++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 214 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  43 +++++
 python/pyspark/sql/connect/proto/types_pb2.py      |  33 +---
 python/pyspark/sql/connect/proto/types_pb2.pyi     |  34 ++--
 python/pyspark/sql/connect/types.py                |  21 +-
 .../sql/tests/connect/test_connect_basic.py        |  23 ++-
 13 files changed, 308 insertions(+), 175 deletions(-)
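
For context, a minimal usage sketch of the new API (not part of the patch; the session setup is an assumption, and the data and metadata values mirror the test added in this commit). `DataFrame.withMetadata` takes a column name and a plain dict, ships it to the server as a JSON string via the new WithMetadata relation, and the metadata then round-trips through the Connect schema:

    from pyspark.sql import SparkSession

    # Assumes a Spark Connect server is reachable at this address.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])

    # Attach metadata to the "age" column; the dict must be JSON-serializable.
    df2 = df.withMetadata("age", {"max_age": 5})

    # The metadata is now visible in the schema fetched from the server.
    print(df2.schema["age"].metadata)   # {'max_age': 5}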

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index c0f22dd4576..45187f6adcf 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -60,6 +60,7 @@ message Relation {
     Unpivot unpivot = 25;
     ToSchema to_schema = 26;
     RepartitionByExpression repartition_by_expression = 27;
+    WithMetadata with_metadata = 28;
 
     // NA functions
     NAFill fill_na = 90;
@@ -694,6 +695,20 @@ message WithColumns {
   repeated Expression.Alias name_expr_list = 2;
 }
 
+
+// Update an existing column with metadata.
+message WithMetadata {
+  // (Required) The input relation.
+  Relation input = 1;
+
+  // (Required) The column.
+  string column = 2;
+
+  // (Required) The JSON-formatted metadata.
+  string metadata = 3;
+}
+
+
 // Specify a hint over a relation. Hint should have a name and optional parameters.
 message Hint {
   // (Required) The input relation.
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/types.proto b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
index 785a191955f..7600c547714 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
@@ -152,7 +152,7 @@ message DataType {
     string name = 1;
     DataType data_type = 2;
     bool nullable = 3;
-    map<string, string> metadata = 4;
+    optional string metadata = 4;
   }
 
   message Struct {
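
To make the wire change above concrete, here is a small sketch (not part of the patch) of how a field's metadata dict travels as the new optional JSON string, assuming the regenerated types_pb2 module is importable; it mirrors what pyspark_types_to_proto_types() and proto_schema_to_pyspark_data_type() below now do:

    import json

    from pyspark.sql.connect.proto import types_pb2 as pb2

    field = pb2.DataType.StructField()
    field.name = "age"
    field.nullable = True

    metadata = {"max_age": 5}
    if metadata:
        # Non-empty metadata is serialized to a JSON string on the wire.
        field.metadata = json.dumps(metadata)

    # The reader only parses the field when it was actually set.
    restored = json.loads(field.metadata) if field.HasField("metadata") else None
    print(restored)   # {'max_age': 5}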
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index 0b8d79596c3..388ce9aebfc 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -95,13 +95,17 @@ object DataTypeProtoConverter {
   }
 
   private def toCatalystStructType(t: proto.DataType.Struct): StructType = {
-    // TODO: support metadata
     val fields = t.getFieldsList.toSeq.map { protoField =>
+      val metadata = if (protoField.hasMetadata) {
+        Metadata.fromJson(protoField.getMetadata)
+      } else {
+        Metadata.empty
+      }
       StructField(
         name = protoField.getName,
         dataType = toCatalystType(protoField.getDataType),
         nullable = protoField.getNullable,
-        metadata = Metadata.empty)
+        metadata = metadata)
     }
     StructType.apply(fields)
   }
@@ -249,19 +253,28 @@ object DataTypeProtoConverter {
           .build()
 
       case StructType(fields: Array[StructField]) =>
-        // TODO: support metadata
         val protoFields = fields.toSeq.map {
           case StructField(
                 name: String,
                 dataType: DataType,
                 nullable: Boolean,
                 metadata: Metadata) =>
-            proto.DataType.StructField
-              .newBuilder()
-              .setName(name)
-              .setDataType(toConnectProtoType(dataType))
-              .setNullable(nullable)
-              .build()
+            if (metadata.equals(Metadata.empty)) {
+              proto.DataType.StructField
+                .newBuilder()
+                .setName(name)
+                .setDataType(toConnectProtoType(dataType))
+                .setNullable(nullable)
+                .build()
+            } else {
+              proto.DataType.StructField
+                .newBuilder()
+                .setName(name)
+                .setDataType(toConnectProtoType(dataType))
+                .setNullable(nullable)
+                .setMetadata(metadata.json)
+                .build()
+            }
         }
         proto.DataType
           .newBuilder()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b4c882541e0..2980fc3a7e0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -106,6 +106,7 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Relation.RelTypeCase.RENAME_COLUMNS_BY_NAME_TO_NAME_MAP =>
         transformRenameColumnsByNameToNameMap(rel.getRenameColumnsByNameToNameMap)
       case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns)
+      case proto.Relation.RelTypeCase.WITH_METADATA => transformWithMetadata(rel.getWithMetadata)
       case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
       case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
       case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
@@ -485,6 +486,13 @@ class SparkConnectPlanner(session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformWithMetadata(rel: proto.WithMetadata): LogicalPlan = {
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .withMetadata(rel.getColumn, Metadata.fromJson(rel.getMetadata))
+      .logicalPlan
+  }
+
   private def transformHint(rel: proto.Hint): LogicalPlan = {
     val params = rel.getParametersList.asScala.map(toCatalystValue).toSeq.map {
       case name: String => UnresolvedAttribute.quotedString(name)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 832b5648676..72839f8def9 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -19,6 +19,7 @@ import logging
 import os
 import urllib.parse
 import uuid
+import json
 from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, NoReturn, cast
 
 import pandas as pd
@@ -447,14 +448,20 @@ class SparkConnectClient(object):
         # Server side should populate the struct field which is the schema.
         assert proto_schema.HasField("struct")
 
-        fields = [
-            StructField(
-                f.name,
-                self._proto_schema_to_pyspark_schema(f.data_type),
-                f.nullable,
+        fields = []
+        for f in proto_schema.struct.fields:
+            if f.HasField("metadata"):
+                metadata = json.loads(f.metadata)
+            else:
+                metadata = None
+            fields.append(
+                StructField(
+                    f.name,
+                    self._proto_schema_to_pyspark_schema(f.data_type),
+                    f.nullable,
+                    metadata,
+                )
             )
-            for f in proto_schema.struct.fields
-        ]
         return StructType(fields)
 
     def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b1331439cf8..8aca9fbb968 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -34,6 +34,7 @@ import sys
 import random
 import pandas
 import datetime
+import json
 import warnings
 from collections.abc import Iterable
 
@@ -472,6 +473,21 @@ class DataFrame:
 
     sample.__doc__ = PySparkDataFrame.sample.__doc__
 
+    def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
+        if not isinstance(metadata, dict):
+            raise TypeError("metadata should be a dict")
+
+        return DataFrame.withPlan(
+            plan.WithMetadata(
+                child=self._plan,
+                column=columnName,
+                metadata=json.dumps(metadata),
+            ),
+            session=self._session,
+        )
+
+    withMetadata.__doc__ = PySparkDataFrame.withMetadata.__doc__
+
     def withColumnRenamed(self, existing: str, new: str) -> "DataFrame":
         return self.withColumnsRenamed({existing: new})
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 1973755be27..ae1b97a61c4 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -17,6 +17,7 @@
 
 from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 import functools
+import json
 import pyarrow as pa
 from inspect import signature, isclass
 
@@ -393,6 +394,27 @@ class WithColumns(LogicalPlan):
         return plan
 
 
+class WithMetadata(LogicalPlan):
+    def __init__(self, child: Optional["LogicalPlan"], column: str, metadata: str) -> None:
+        super().__init__(child)
+
+        assert isinstance(column, str)
+        assert isinstance(metadata, str)
+        # validate json string
+        json.loads(metadata)
+
+        self._column = column
+        self._metadata = metadata
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = proto.Relation()
+        plan.with_metadata.input.CopyFrom(self._child.plan(session))
+        plan.with_metadata.column = self._column
+        plan.with_metadata.metadata = self._metadata
+        return plan
+
+
 class Hint(LogicalPlan):
     """Logical plan object for a Hint operation."""
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 9e230c3d239..fd3e209a78f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -85,6 +85,7 @@ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY = (
     _RENAMECOLUMNSBYNAMETONAMEMAP.nested_types_by_name["RenameColumnsMapEntry"]
 )
 _WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"]
+_WITHMETADATA = DESCRIPTOR.message_types_by_name["WithMetadata"]
 _HINT = DESCRIPTOR.message_types_by_name["Hint"]
 _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
@@ -558,6 +559,17 @@ WithColumns = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(WithColumns)
 
+WithMetadata = _reflection.GeneratedProtocolMessageType(
+    "WithMetadata",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _WITHMETADATA,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.WithMetadata)
+    },
+)
+_sym_db.RegisterMessage(WithMetadata)
+
 Hint = _reflection.GeneratedProtocolMessageType(
     "Hint",
     (_message.Message,),
@@ -611,103 +623,105 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None
     _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2578
-    _UNKNOWN._serialized_start = 2580
-    _UNKNOWN._serialized_end = 2589
-    _RELATIONCOMMON._serialized_start = 2591
-    _RELATIONCOMMON._serialized_end = 2640
-    _SQL._serialized_start = 2642
-    _SQL._serialized_end = 2669
-    _READ._serialized_start = 2672
-    _READ._serialized_end = 3098
-    _READ_NAMEDTABLE._serialized_start = 2814
-    _READ_NAMEDTABLE._serialized_end = 2875
-    _READ_DATASOURCE._serialized_start = 2878
-    _READ_DATASOURCE._serialized_end = 3085
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3016
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3074
-    _PROJECT._serialized_start = 3100
-    _PROJECT._serialized_end = 3217
-    _FILTER._serialized_start = 3219
-    _FILTER._serialized_end = 3331
-    _JOIN._serialized_start = 3334
-    _JOIN._serialized_end = 3805
-    _JOIN_JOINTYPE._serialized_start = 3597
-    _JOIN_JOINTYPE._serialized_end = 3805
-    _SETOPERATION._serialized_start = 3808
-    _SETOPERATION._serialized_end = 4204
-    _SETOPERATION_SETOPTYPE._serialized_start = 4067
-    _SETOPERATION_SETOPTYPE._serialized_end = 4181
-    _LIMIT._serialized_start = 4206
-    _LIMIT._serialized_end = 4282
-    _OFFSET._serialized_start = 4284
-    _OFFSET._serialized_end = 4363
-    _TAIL._serialized_start = 4365
-    _TAIL._serialized_end = 4440
-    _AGGREGATE._serialized_start = 4443
-    _AGGREGATE._serialized_end = 5025
-    _AGGREGATE_PIVOT._serialized_start = 4782
-    _AGGREGATE_PIVOT._serialized_end = 4893
-    _AGGREGATE_GROUPTYPE._serialized_start = 4896
-    _AGGREGATE_GROUPTYPE._serialized_end = 5025
-    _SORT._serialized_start = 5028
-    _SORT._serialized_end = 5188
-    _DROP._serialized_start = 5190
-    _DROP._serialized_end = 5290
-    _DEDUPLICATE._serialized_start = 5293
-    _DEDUPLICATE._serialized_end = 5464
-    _LOCALRELATION._serialized_start = 5466
-    _LOCALRELATION._serialized_end = 5555
-    _SAMPLE._serialized_start = 5558
-    _SAMPLE._serialized_end = 5831
-    _RANGE._serialized_start = 5834
-    _RANGE._serialized_end = 5979
-    _SUBQUERYALIAS._serialized_start = 5981
-    _SUBQUERYALIAS._serialized_end = 6095
-    _REPARTITION._serialized_start = 6098
-    _REPARTITION._serialized_end = 6240
-    _SHOWSTRING._serialized_start = 6243
-    _SHOWSTRING._serialized_end = 6385
-    _STATSUMMARY._serialized_start = 6387
-    _STATSUMMARY._serialized_end = 6479
-    _STATDESCRIBE._serialized_start = 6481
-    _STATDESCRIBE._serialized_end = 6562
-    _STATCROSSTAB._serialized_start = 6564
-    _STATCROSSTAB._serialized_end = 6665
-    _STATCOV._serialized_start = 6667
-    _STATCOV._serialized_end = 6763
-    _STATCORR._serialized_start = 6766
-    _STATCORR._serialized_end = 6903
-    _STATAPPROXQUANTILE._serialized_start = 6906
-    _STATAPPROXQUANTILE._serialized_end = 7070
-    _STATFREQITEMS._serialized_start = 7072
-    _STATFREQITEMS._serialized_end = 7197
-    _STATSAMPLEBY._serialized_start = 7200
-    _STATSAMPLEBY._serialized_end = 7509
-    _STATSAMPLEBY_FRACTION._serialized_start = 7401
-    _STATSAMPLEBY_FRACTION._serialized_end = 7500
-    _NAFILL._serialized_start = 7512
-    _NAFILL._serialized_end = 7646
-    _NADROP._serialized_start = 7649
-    _NADROP._serialized_end = 7783
-    _NAREPLACE._serialized_start = 7786
-    _NAREPLACE._serialized_end = 8082
-    _NAREPLACE_REPLACEMENT._serialized_start = 7941
-    _NAREPLACE_REPLACEMENT._serialized_end = 8082
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8084
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8198
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8201
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8460
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8393
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8460
-    _WITHCOLUMNS._serialized_start = 8463
-    _WITHCOLUMNS._serialized_end = 8594
-    _HINT._serialized_start = 8597
-    _HINT._serialized_end = 8737
-    _UNPIVOT._serialized_start = 8740
-    _UNPIVOT._serialized_end = 8986
-    _TOSCHEMA._serialized_start = 8988
-    _TOSCHEMA._serialized_end = 9094
-    _REPARTITIONBYEXPRESSION._serialized_start = 9097
-    _REPARTITIONBYEXPRESSION._serialized_end = 9300
+    _RELATION._serialized_end = 2646
+    _UNKNOWN._serialized_start = 2648
+    _UNKNOWN._serialized_end = 2657
+    _RELATIONCOMMON._serialized_start = 2659
+    _RELATIONCOMMON._serialized_end = 2708
+    _SQL._serialized_start = 2710
+    _SQL._serialized_end = 2737
+    _READ._serialized_start = 2740
+    _READ._serialized_end = 3166
+    _READ_NAMEDTABLE._serialized_start = 2882
+    _READ_NAMEDTABLE._serialized_end = 2943
+    _READ_DATASOURCE._serialized_start = 2946
+    _READ_DATASOURCE._serialized_end = 3153
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3084
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3142
+    _PROJECT._serialized_start = 3168
+    _PROJECT._serialized_end = 3285
+    _FILTER._serialized_start = 3287
+    _FILTER._serialized_end = 3399
+    _JOIN._serialized_start = 3402
+    _JOIN._serialized_end = 3873
+    _JOIN_JOINTYPE._serialized_start = 3665
+    _JOIN_JOINTYPE._serialized_end = 3873
+    _SETOPERATION._serialized_start = 3876
+    _SETOPERATION._serialized_end = 4272
+    _SETOPERATION_SETOPTYPE._serialized_start = 4135
+    _SETOPERATION_SETOPTYPE._serialized_end = 4249
+    _LIMIT._serialized_start = 4274
+    _LIMIT._serialized_end = 4350
+    _OFFSET._serialized_start = 4352
+    _OFFSET._serialized_end = 4431
+    _TAIL._serialized_start = 4433
+    _TAIL._serialized_end = 4508
+    _AGGREGATE._serialized_start = 4511
+    _AGGREGATE._serialized_end = 5093
+    _AGGREGATE_PIVOT._serialized_start = 4850
+    _AGGREGATE_PIVOT._serialized_end = 4961
+    _AGGREGATE_GROUPTYPE._serialized_start = 4964
+    _AGGREGATE_GROUPTYPE._serialized_end = 5093
+    _SORT._serialized_start = 5096
+    _SORT._serialized_end = 5256
+    _DROP._serialized_start = 5258
+    _DROP._serialized_end = 5358
+    _DEDUPLICATE._serialized_start = 5361
+    _DEDUPLICATE._serialized_end = 5532
+    _LOCALRELATION._serialized_start = 5534
+    _LOCALRELATION._serialized_end = 5623
+    _SAMPLE._serialized_start = 5626
+    _SAMPLE._serialized_end = 5899
+    _RANGE._serialized_start = 5902
+    _RANGE._serialized_end = 6047
+    _SUBQUERYALIAS._serialized_start = 6049
+    _SUBQUERYALIAS._serialized_end = 6163
+    _REPARTITION._serialized_start = 6166
+    _REPARTITION._serialized_end = 6308
+    _SHOWSTRING._serialized_start = 6311
+    _SHOWSTRING._serialized_end = 6453
+    _STATSUMMARY._serialized_start = 6455
+    _STATSUMMARY._serialized_end = 6547
+    _STATDESCRIBE._serialized_start = 6549
+    _STATDESCRIBE._serialized_end = 6630
+    _STATCROSSTAB._serialized_start = 6632
+    _STATCROSSTAB._serialized_end = 6733
+    _STATCOV._serialized_start = 6735
+    _STATCOV._serialized_end = 6831
+    _STATCORR._serialized_start = 6834
+    _STATCORR._serialized_end = 6971
+    _STATAPPROXQUANTILE._serialized_start = 6974
+    _STATAPPROXQUANTILE._serialized_end = 7138
+    _STATFREQITEMS._serialized_start = 7140
+    _STATFREQITEMS._serialized_end = 7265
+    _STATSAMPLEBY._serialized_start = 7268
+    _STATSAMPLEBY._serialized_end = 7577
+    _STATSAMPLEBY_FRACTION._serialized_start = 7469
+    _STATSAMPLEBY_FRACTION._serialized_end = 7568
+    _NAFILL._serialized_start = 7580
+    _NAFILL._serialized_end = 7714
+    _NADROP._serialized_start = 7717
+    _NADROP._serialized_end = 7851
+    _NAREPLACE._serialized_start = 7854
+    _NAREPLACE._serialized_end = 8150
+    _NAREPLACE_REPLACEMENT._serialized_start = 8009
+    _NAREPLACE_REPLACEMENT._serialized_end = 8150
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8152
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8266
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8269
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8528
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8461
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8528
+    _WITHCOLUMNS._serialized_start = 8531
+    _WITHCOLUMNS._serialized_end = 8662
+    _WITHMETADATA._serialized_start = 8664
+    _WITHMETADATA._serialized_end = 8777
+    _HINT._serialized_start = 8780
+    _HINT._serialized_end = 8920
+    _UNPIVOT._serialized_start = 8923
+    _UNPIVOT._serialized_end = 9169
+    _TOSCHEMA._serialized_start = 9171
+    _TOSCHEMA._serialized_end = 9277
+    _REPARTITIONBYEXPRESSION._serialized_start = 9280
+    _REPARTITIONBYEXPRESSION._serialized_end = 9483
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 811f005d24b..11b11394c0c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -89,6 +89,7 @@ class Relation(google.protobuf.message.Message):
     UNPIVOT_FIELD_NUMBER: builtins.int
     TO_SCHEMA_FIELD_NUMBER: builtins.int
     REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
+    WITH_METADATA_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -158,6 +159,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def repartition_by_expression(self) -> global___RepartitionByExpression: ...
     @property
+    def with_metadata(self) -> global___WithMetadata: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -221,6 +224,7 @@ class Relation(google.protobuf.message.Message):
         unpivot: global___Unpivot | None = ...,
         to_schema: global___ToSchema | None = ...,
         repartition_by_expression: global___RepartitionByExpression | None = ...,
+        with_metadata: global___WithMetadata | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -323,6 +327,8 @@ class Relation(google.protobuf.message.Message):
             b"unpivot",
             "with_columns",
             b"with_columns",
+            "with_metadata",
+            b"with_metadata",
         ],
     ) -> builtins.bool: ...
     def ClearField(
@@ -412,6 +418,8 @@ class Relation(google.protobuf.message.Message):
             b"unpivot",
             "with_columns",
             b"with_columns",
+            "with_metadata",
+            b"with_metadata",
         ],
     ) -> None: ...
     def WhichOneof(
@@ -443,6 +451,7 @@ class Relation(google.protobuf.message.Message):
         "unpivot",
         "to_schema",
         "repartition_by_expression",
+        "with_metadata",
         "fill_na",
         "drop_na",
         "replace",
@@ -2356,6 +2365,40 @@ class WithColumns(google.protobuf.message.Message):
 
 global___WithColumns = WithColumns
 
+class WithMetadata(google.protobuf.message.Message):
+    """Update an existing column with metadata."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLUMN_FIELD_NUMBER: builtins.int
+    METADATA_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) The input relation."""
+    column: builtins.str
+    """(Required) The column."""
+    metadata: builtins.str
+    """(Required) The JSON-formatted metadata."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        column: builtins.str = ...,
+        metadata: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "column", b"column", "input", b"input", "metadata", b"metadata"
+        ],
+    ) -> None: ...
+
+global___WithMetadata = WithMetadata
+
 class Hint(google.protobuf.message.Message):
     """Specify a hint over a relation. Hint should have a name and optional 
parameters."""
 
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index 2fcb56acb4d..4b2137d2901 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,7 +30,7 @@ _sym_db = _symbol_database.Default()
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xec\x1d\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02
 
\x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04
 
\x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05
 \x01( [...]
+    
b'\n\x19spark/connect/types.proto\x12\rspark.connect"\x8e\x1d\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01
 
\x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02
 
\x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03
 
\x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04
 
\x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05
 \x01( [...]
 )
 
 
@@ -55,7 +55,6 @@ _DATATYPE_CHAR = _DATATYPE.nested_types_by_name["Char"]
 _DATATYPE_VARCHAR = _DATATYPE.nested_types_by_name["VarChar"]
 _DATATYPE_DECIMAL = _DATATYPE.nested_types_by_name["Decimal"]
 _DATATYPE_STRUCTFIELD = _DATATYPE.nested_types_by_name["StructField"]
-_DATATYPE_STRUCTFIELD_METADATAENTRY = _DATATYPE_STRUCTFIELD.nested_types_by_name["MetadataEntry"]
 _DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"]
 _DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"]
 _DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"]
@@ -238,15 +237,6 @@ DataType = _reflection.GeneratedProtocolMessageType(
             "StructField",
             (_message.Message,),
             {
-                "MetadataEntry": _reflection.GeneratedProtocolMessageType(
-                    "MetadataEntry",
-                    (_message.Message,),
-                    {
-                        "DESCRIPTOR": _DATATYPE_STRUCTFIELD_METADATAENTRY,
-                        "__module__": "spark.connect.types_pb2"
-                        # @@protoc_insertion_point(class_scope:spark.connect.DataType.StructField.MetadataEntry)
-                    },
-                ),
                 "DESCRIPTOR": _DATATYPE_STRUCTFIELD,
                 "__module__": "spark.connect.types_pb2"
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.StructField)
@@ -305,7 +295,6 @@ _sym_db.RegisterMessage(DataType.Char)
 _sym_db.RegisterMessage(DataType.VarChar)
 _sym_db.RegisterMessage(DataType.Decimal)
 _sym_db.RegisterMessage(DataType.StructField)
-_sym_db.RegisterMessage(DataType.StructField.MetadataEntry)
 _sym_db.RegisterMessage(DataType.Struct)
 _sym_db.RegisterMessage(DataType.Array)
 _sym_db.RegisterMessage(DataType.Map)
@@ -314,10 +303,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._options = None
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_options = b"8\001"
     _DATATYPE._serialized_start = 45
-    _DATATYPE._serialized_end = 3865
+    _DATATYPE._serialized_end = 3771
     _DATATYPE_BOOLEAN._serialized_start = 1421
     _DATATYPE_BOOLEAN._serialized_end = 1488
     _DATATYPE_BYTE._serialized_start = 1490
@@ -357,13 +344,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _DATATYPE_DECIMAL._serialized_start = 2930
     _DATATYPE_DECIMAL._serialized_end = 3083
     _DATATYPE_STRUCTFIELD._serialized_start = 3086
-    _DATATYPE_STRUCTFIELD._serialized_end = 3341
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_start = 3282
-    _DATATYPE_STRUCTFIELD_METADATAENTRY._serialized_end = 3341
-    _DATATYPE_STRUCT._serialized_start = 3343
-    _DATATYPE_STRUCT._serialized_end = 3470
-    _DATATYPE_ARRAY._serialized_start = 3473
-    _DATATYPE_ARRAY._serialized_end = 3635
-    _DATATYPE_MAP._serialized_start = 3638
-    _DATATYPE_MAP._serialized_end = 3857
+    _DATATYPE_STRUCTFIELD._serialized_end = 3247
+    _DATATYPE_STRUCT._serialized_start = 3249
+    _DATATYPE_STRUCT._serialized_end = 3376
+    _DATATYPE_ARRAY._serialized_start = 3379
+    _DATATYPE_ARRAY._serialized_end = 3541
+    _DATATYPE_MAP._serialized_start = 3544
+    _DATATYPE_MAP._serialized_end = 3763
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index 72736301f88..3db884af5f2 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -497,23 +497,6 @@ class DataType(google.protobuf.message.Message):
     class StructField(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        class MetadataEntry(google.protobuf.message.Message):
-            DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-            KEY_FIELD_NUMBER: builtins.int
-            VALUE_FIELD_NUMBER: builtins.int
-            key: builtins.str
-            value: builtins.str
-            def __init__(
-                self,
-                *,
-                key: builtins.str = ...,
-                value: builtins.str = ...,
-            ) -> None: ...
-            def ClearField(
-                self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
-            ) -> None: ...
-
         NAME_FIELD_NUMBER: builtins.int
         DATA_TYPE_FIELD_NUMBER: builtins.int
         NULLABLE_FIELD_NUMBER: builtins.int
@@ -522,24 +505,26 @@ class DataType(google.protobuf.message.Message):
         @property
         def data_type(self) -> global___DataType: ...
         nullable: builtins.bool
-        @property
-        def metadata(
-            self,
-        ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ...
+        metadata: builtins.str
         def __init__(
             self,
             *,
             name: builtins.str = ...,
             data_type: global___DataType | None = ...,
             nullable: builtins.bool = ...,
-            metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+            metadata: builtins.str | None = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["data_type", b"data_type"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_metadata", b"_metadata", "data_type", b"data_type", "metadata", b"metadata"
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_metadata",
+                b"_metadata",
                 "data_type",
                 b"data_type",
                 "metadata",
@@ -550,6 +535,9 @@ class DataType(google.protobuf.message.Message):
                 b"nullable",
             ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
+        ) -> typing_extensions.Literal["metadata"] | None: ...
 
     class Struct(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 1c6c3fa8c21..2f4abcec9b3 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 #
 
+import json
+
 from typing import Optional
 
 from pyspark.sql.types import (
@@ -94,6 +96,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
             struct_field.name = field.name
             struct_field.data_type.CopyFrom(pyspark_types_to_proto_types(field.dataType))
             struct_field.nullable = field.nullable
+            if field.metadata is not None and len(field.metadata) > 0:
+                struct_field.metadata = json.dumps(field.metadata)
             ret.struct.fields.append(struct_field)
     elif isinstance(data_type, MapType):
         ret.map.key_type.CopyFrom(pyspark_types_to_proto_types(data_type.keyType))
@@ -160,14 +164,17 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
             schema.array.contains_null,
         )
     elif schema.HasField("struct"):
-        fields = [
-            StructField(
-                f.name,
-                proto_schema_to_pyspark_data_type(f.data_type),
-                f.nullable,
+        fields = []
+        for f in schema.struct.fields:
+            if f.HasField("metadata"):
+                metadata = json.loads(f.metadata)
+            else:
+                metadata = None
+            fields.append(
+                StructField(
+                    f.name, proto_schema_to_pyspark_data_type(f.data_type), f.nullable, metadata
+                )
             )
-            for f in schema.struct.fields
-        ]
         return StructType(fields)
     elif schema.HasField("map"):
         return MapType(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 356eb17ee05..72e60712b98 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -671,11 +671,9 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             (float(3.0), double(3.0), NULL, "3", false, NULL, ARRAY(NULL), NULL)
             AS tab(a, b, c, d, e, f, g, h)
             """
-        # compare the __repr__() to ignore the metadata for now
-        # the metadata is not supported in Connect for now
         self.assertEqual(
-            self.spark.sql(query).schema.__repr__(),
-            self.connect.sql(query).schema.__repr__(),
+            self.spark.sql(query).schema,
+            self.connect.sql(query).schema,
         )
 
     def test_to(self):
@@ -1971,6 +1969,23 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         ):
             cdf.groupBy("name").pivot("department").sum("salary", "department").show()
 
+    def test_with_metadata(self):
+        cdf = self.connect.createDataFrame(data=[(2, "Alice"), (5, "Bob")], schema=["age", "name"])
+        self.assertEqual(cdf.schema["age"].metadata, {})
+        self.assertEqual(cdf.schema["name"].metadata, {})
+
+        cdf1 = cdf.withMetadata(columnName="age", metadata={"max_age": 5})
+        self.assertEqual(cdf1.schema["age"].metadata, {"max_age": 5})
+
+        cdf2 = cdf.withMetadata(columnName="name", metadata={"names": ["Alice", "Bob"]})
+        self.assertEqual(cdf2.schema["name"].metadata, {"names": ["Alice", "Bob"]})
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "metadata should be a dict",
+        ):
+            cdf.withMetadata(columnName="name", metadata=["magic"])
+
     def test_unsupported_functions(self):
         # SPARK-41225: Disable unsupported functions.
         df = self.connect.read.table(self.tbl_name)

