This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dc6cf74dc2d [SPARK-42020][CONNECT][PYTHON] Support UserDefinedType in Spark Connect
dc6cf74dc2d is described below

commit dc6cf74dc2db54d935cf54cb3e4829a468dcdf78
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Mar 20 09:24:26 2023 +0900

    [SPARK-42020][CONNECT][PYTHON] Support UserDefinedType in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Supports `UserDefinedType` in Spark Connect: the server now returns the Spark schema in `ExecutePlanResponse`, UDT values are carried in the Arrow batches as their underlying `sqlType`, and the Python client serializes/deserializes UDT values in `createDataFrame` and `collect`.
    
    ### Why are the changes needed?
    
    Currently Spark Connect doesn't support UDTs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, UDTs will be available in Spark Connect.
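
    For example, a minimal sketch of what this enables against a Spark Connect session. The `Point`/`PointUDT` pair below is a hypothetical example UDT (not part of this patch), and `spark` is assumed to be a Connect session created via `SparkSession.builder.remote(...)`:

    ```py
    from pyspark.sql import Row
    from pyspark.sql.types import DoubleType, StructField, StructType, UserDefinedType


    class PointUDT(UserDefinedType):
        # Hypothetical UDT stored as struct<x: double, y: double>.
        @classmethod
        def sqlType(cls):
            return StructType(
                [StructField("x", DoubleType(), False), StructField("y", DoubleType(), False)]
            )

        @classmethod
        def module(cls):
            return "__main__"

        def serialize(self, obj):
            # Python object -> its sqlType representation (what goes into the Arrow batch).
            return (obj.x, obj.y)

        def deserialize(self, datum):
            # sqlType representation -> Python object (rebuilt on the client in collect()).
            return Point(datum[0], datum[1])


    class Point:
        __UDT__ = PointUDT()

        def __init__(self, x, y):
            self.x, self.y = x, y


    df = spark.createDataFrame([Row(label="a", point=Point(1.0, 2.0))])
    [row] = df.collect()
    assert isinstance(row.point, Point)
    ```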
    
    ### How was this patch tested?
    
    Enabled the related tests.
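
    Specifically, the UDT tests that were previously skipped in `test_parity_types.py`
    (`test_apply_schema_with_udt`, `test_cast_to_string_with_udt`, `test_cast_to_udt_with_udt`,
    `test_complex_nested_udt_in_df`, `test_infer_schema_with_udt`, `test_nested_udt_in_df`,
    `test_parquet_with_udt`, `test_simple_udt_in_df`, `test_udt_with_none`, and
    `test_union_with_udt`) are re-enabled, and `SparkConnectServiceSuite` is updated to
    account for the additional schema-only response.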
    
    Closes #40402 from ueshin/issues/SPARK-42020/udt.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../src/main/protobuf/spark/connect/base.proto     |   3 +
 .../connect/common/DataTypeProtoConverter.scala    |   1 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  34 +++++--
 .../service/SparkConnectStreamHandler.scala        |  12 +++
 .../connect/planner/SparkConnectServiceSuite.scala |  18 +++-
 python/pyspark/sql/connect/client.py               |  27 ++++--
 python/pyspark/sql/connect/conversion.py           |  31 ++++++
 python/pyspark/sql/connect/dataframe.py            |   4 +-
 python/pyspark/sql/connect/proto/base_pb2.py       | 108 ++++++++++-----------
 python/pyspark/sql/connect/proto/base_pb2.pyi      |   9 ++
 python/pyspark/sql/connect/session.py              |  13 ++-
 python/pyspark/sql/connect/types.py                |   2 +
 .../pyspark/sql/tests/connect/test_parity_types.py |  56 +----------
 python/pyspark/sql/tests/test_types.py             |   8 +-
 .../org/apache/spark/sql/util/ArrowUtils.scala     |   1 +
 15 files changed, 188 insertions(+), 139 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 809a2dc5cbf..da0f974a749 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -272,6 +272,9 @@ message ExecutePlanResponse {
   // The metrics observed during the execution of the query plan.
   repeated ObservedMetrics observed_metrics = 6;
 
+  // (Optional) The Spark schema. This field is available when `collect` is called.
+  DataType schema = 7;
+
   // A SQL command returns an opaque Relation that can be directly used as input for the next
   // call.
   message SqlCommandResult {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index c30ea8c8301..28ddbe844d4 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -335,6 +335,7 @@ object DataTypeProtoConverter {
               .setType("udt")
               .setPythonClass(pyudt.pyUDT)
               .setSqlType(toConnectProtoType(pyudt.sqlType))
+              .setSerializedPythonClass(pyudt.serializedPyClass)
               .build())
           .build()
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 20db252c057..b023adac98a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -676,16 +676,36 @@ class SparkConnectPlanner(val session: SparkSession) {
       }
       val attributes = structType.toAttributes
       val proj = UnsafeProjection.create(attributes, attributes)
-      val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
+      val data = rows.map(proj)
 
       if (schema == null) {
-        relation
+        logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
       } else {
-        Dataset
-          .ofRows(session, logicalPlan = relation)
-          .toDF(schema.names: _*)
-          .to(schema)
+        def udtToSqlType(dt: DataType): DataType = dt match {
+          case udt: UserDefinedType[_] => udt.sqlType
+          case StructType(fields) =>
+            val newFields = fields.map { case StructField(name, dataType, nullable, metadata) =>
+              StructField(name, udtToSqlType(dataType), nullable, metadata)
+            }
+            StructType(newFields)
+          case ArrayType(elementType, containsNull) =>
+            ArrayType(udtToSqlType(elementType), containsNull)
+          case MapType(keyType, valueType, valueContainsNull) =>
+            MapType(udtToSqlType(keyType), udtToSqlType(valueType), valueContainsNull)
+          case _ => dt
+        }
+
+        val sqlTypeOnlySchema = udtToSqlType(schema).asInstanceOf[StructType]
+
+        val project = Dataset
+          .ofRows(session, logicalPlan = logical.LocalRelation(attributes))
+          .toDF(sqlTypeOnlySchema.names: _*)
+          .to(sqlTypeOnlySchema)
           .logicalPlan
+          .asInstanceOf[Project]
+
+        val proj = UnsafeProjection.create(project.projectList, project.child.output)
+        logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq)
       }
     } else {
       if (schema == null) {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 104d840ed52..335b871d499 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -28,6 +28,7 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -60,6 +61,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(session)
     val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
+    responseObserver.onNext(
+      SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema))
     processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
     responseObserver.onNext(
       SparkConnectStreamHandler.sendMetricsToResponse(request.getSessionId, dataframe))
@@ -203,6 +206,15 @@ object SparkConnectStreamHandler {
     }
   }
 
+  def sendSchemaToResponse(sessionId: String, schema: StructType): ExecutePlanResponse = {
+    // Send the Spark data type
+    ExecutePlanResponse
+      .newBuilder()
+      .setSessionId(sessionId)
+      .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
+      .build()
+  }
+
   def sendMetricsToResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = {
     // Send a last batch with the metrics
     ExecutePlanResponse
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index e2aecaaea86..c36ba76f984 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -160,18 +160,22 @@ class SparkConnectServiceSuite extends SharedSparkSession {
     assert(done)
 
     // 4 Partitions + Metrics
-    assert(responses.size == 5)
+    assert(responses.size == 6)
+
+    // Make sure the first response is schema only
+    val head = responses.head
+    assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
 
     // Make sure the last response is metrics only
     val last = responses.last
-    assert(last.hasMetrics && !last.hasArrowBatch)
+    assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
 
     val allocator = new RootAllocator()
 
     // Check the 'data' batches
     var expectedId = 0L
     var previousEId = 0.0d
-    responses.dropRight(1).foreach { response =>
+    responses.tail.dropRight(1).foreach { response =>
       assert(response.hasArrowBatch)
       val batch = response.getArrowBatch
       assert(batch.getData != null)
@@ -347,11 +351,15 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       // The current implementation is expected to be blocking. This is here to make sure it is.
       assert(done)
 
-      assert(responses.size == 6)
+      assert(responses.size == 7)
+
+      // Make sure the first response is schema only
+      val head = responses.head
+      assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
 
       // Make sure the last response is observed metrics only
       val last = responses.last
-      assert(last.getObservedMetricsCount == 1 && !last.hasArrowBatch)
+      assert(last.getObservedMetricsCount == 1 && !last.hasSchema && !last.hasArrowBatch)
 
       val observedMetricsList = last.getObservedMetricsList.asScala
       val observedMetric = observedMetricsList.head
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8dd80a931b9..090d239fbb4 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -619,16 +619,16 @@ class SparkConnectClient(object):
             for x in metrics
         ]
 
-    def to_table(self, plan: pb2.Plan) -> "pa.Table":
+    def to_table(self, plan: pb2.Plan) -> Tuple["pa.Table", Optional[StructType]]:
         """
         Return given plan as a PyArrow Table.
         """
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, _, _, _3 = self._execute_and_fetch(req)
+        table, schema, _, _, _ = self._execute_and_fetch(req)
         assert table is not None
-        return table
+        return table, schema
 
     def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
         """
@@ -637,7 +637,7 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, metrics, observed_metrics, _ = self._execute_and_fetch(req)
+        table, _, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
         column_names = table.column_names
         table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
@@ -696,7 +696,7 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        data, _, _, properties = self._execute_and_fetch(req)
+        data, _, _, _, properties = self._execute_and_fetch(req)
         if data is not None:
             return (data.to_pandas(), properties)
         else:
@@ -844,12 +844,19 @@ class SparkConnectClient(object):
 
     def _execute_and_fetch(
         self, req: pb2.ExecutePlanRequest
-    ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], List[PlanObservedMetrics], Dict[str, Any]]:
+    ) -> Tuple[
+        Optional["pa.Table"],
+        Optional[StructType],
+        List[PlanMetrics],
+        List[PlanObservedMetrics],
+        Dict[str, Any],
+    ]:
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
         om: List[pb2.ExecutePlanResponse.ObservedMetrics] = []
         batches: List[pa.RecordBatch] = []
+        schema: Optional[StructType] = None
         properties = {}
         try:
             for attempt in Retrying(
@@ -869,6 +876,10 @@ class SparkConnectClient(object):
                         if b.observed_metrics is not None:
                             logger.debug("Received observed metric batch.")
                             om.extend(b.observed_metrics)
+                        if b.HasField("schema"):
+                            dt = types.proto_schema_to_pyspark_data_type(b.schema)
+                            assert isinstance(dt, StructType)
+                            schema = dt
                         if b.HasField("sql_command_result"):
                             properties["sql_command_result"] = 
b.sql_command_result.relation
                         if b.HasField("arrow_batch"):
@@ -888,9 +899,9 @@ class SparkConnectClient(object):
 
         if len(batches) > 0:
             table = pa.Table.from_batches(batches=batches)
-            return table, metrics, observed_metrics, properties
+            return table, schema, metrics, observed_metrics, properties
         else:
-            return None, metrics, observed_metrics, properties
+            return None, schema, metrics, observed_metrics, properties
 
     def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         req = pb2.ConfigRequest()
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 2b16fc7766d..ba488d4d04e 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -37,6 +37,7 @@ from pyspark.sql.types import (
     NullType,
     DecimalType,
     StringType,
+    UserDefinedType,
 )
 
 from pyspark.sql.connect.types import to_arrow_schema
@@ -79,6 +80,8 @@ class LocalDataToArrowConversion:
         elif isinstance(dataType, StringType):
             # Coercion to StringType is allowed
             return True
+        elif isinstance(dataType, UserDefinedType):
+            return True
         else:
             return False
 
@@ -229,6 +232,19 @@ class LocalDataToArrowConversion:
 
             return convert_string
 
+        elif isinstance(dataType, UserDefinedType):
+            udt: UserDefinedType = dataType
+
+            conv = LocalDataToArrowConversion._create_converter(dataType.sqlType())
+
+            def convert_udt(value: Any) -> Any:
+                if value is None:
+                    return None
+                else:
+                    return conv(udt.serialize(value))
+
+            return convert_udt
+
         else:
 
             return lambda value: value
@@ -286,6 +302,8 @@ class ArrowTableToRowsConversion:
         elif isinstance(dataType, (TimestampType, TimestampNTZType)):
             # Always remove the time zone info for now
             return True
+        elif isinstance(dataType, UserDefinedType):
+            return True
         else:
             return False
 
@@ -380,6 +398,19 @@ class ArrowTableToRowsConversion:
 
             return convert_timestample
 
+        elif isinstance(dataType, UserDefinedType):
+            udt: UserDefinedType = dataType
+
+            conv = ArrowTableToRowsConversion._create_converter(dataType.sqlType())
+
+            def convert_udt(value: Any) -> Any:
+                if value is None:
+                    return None
+                else:
+                    return udt.deserialize(conv(value))
+
+            return convert_udt
+
         else:
 
             return lambda value: value
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f1968bc0ad9..03df5781197 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1348,9 +1348,9 @@ class DataFrame:
         if self._session is None:
             raise Exception("Cannot collect on empty session.")
         query = self._plan.to_proto(self._session.client)
-        table = self._session.client.to_table(query)
+        table, schema = self._session.client.to_table(query)
 
-        schema = from_arrow_schema(table.schema)
+        schema = schema or from_arrow_schema(table.schema)
 
         assert schema is not None and isinstance(schema, StructType)
 
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 5a7a10e78ca..36557344893 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06 [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06 [...]
 )
 
 
@@ -690,57 +690,57 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST._serialized_start = 3573
     _EXECUTEPLANREQUEST._serialized_end = 3782
     _EXECUTEPLANRESPONSE._serialized_start = 3785
-    _EXECUTEPLANRESPONSE._serialized_end = 5011
-    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 4242
-    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 4313
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 4315
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 4376
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 4379
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4896
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 4474
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4806
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4683
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4806
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4808
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4896
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4898
-    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 4994
-    _KEYVALUE._serialized_start = 5013
-    _KEYVALUE._serialized_end = 5078
-    _CONFIGREQUEST._serialized_start = 5081
-    _CONFIGREQUEST._serialized_end = 6109
-    _CONFIGREQUEST_OPERATION._serialized_start = 5301
-    _CONFIGREQUEST_OPERATION._serialized_end = 5799
-    _CONFIGREQUEST_SET._serialized_start = 5801
-    _CONFIGREQUEST_SET._serialized_end = 5853
-    _CONFIGREQUEST_GET._serialized_start = 5855
-    _CONFIGREQUEST_GET._serialized_end = 5880
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5882
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5945
-    _CONFIGREQUEST_GETOPTION._serialized_start = 5947
-    _CONFIGREQUEST_GETOPTION._serialized_end = 5978
-    _CONFIGREQUEST_GETALL._serialized_start = 5980
-    _CONFIGREQUEST_GETALL._serialized_end = 6028
-    _CONFIGREQUEST_UNSET._serialized_start = 6030
-    _CONFIGREQUEST_UNSET._serialized_end = 6057
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 6059
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 6093
-    _CONFIGRESPONSE._serialized_start = 6111
-    _CONFIGRESPONSE._serialized_end = 6233
-    _ADDARTIFACTSREQUEST._serialized_start = 6236
-    _ADDARTIFACTSREQUEST._serialized_end = 7107
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 6623
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6676
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6678
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6789
-    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6791
-    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6884
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6887
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 7080
-    _ADDARTIFACTSRESPONSE._serialized_start = 7110
-    _ADDARTIFACTSRESPONSE._serialized_end = 7298
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 7217
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 7298
-    _SPARKCONNECTSERVICE._serialized_start = 7301
-    _SPARKCONNECTSERVICE._serialized_end = 7666
+    _EXECUTEPLANRESPONSE._serialized_end = 5060
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 4291
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 4362
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 4364
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 4425
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 4428
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4945
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 4523
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4855
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4732
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4855
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4857
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4945
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4947
+    _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 5043
+    _KEYVALUE._serialized_start = 5062
+    _KEYVALUE._serialized_end = 5127
+    _CONFIGREQUEST._serialized_start = 5130
+    _CONFIGREQUEST._serialized_end = 6158
+    _CONFIGREQUEST_OPERATION._serialized_start = 5350
+    _CONFIGREQUEST_OPERATION._serialized_end = 5848
+    _CONFIGREQUEST_SET._serialized_start = 5850
+    _CONFIGREQUEST_SET._serialized_end = 5902
+    _CONFIGREQUEST_GET._serialized_start = 5904
+    _CONFIGREQUEST_GET._serialized_end = 5929
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5931
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5994
+    _CONFIGREQUEST_GETOPTION._serialized_start = 5996
+    _CONFIGREQUEST_GETOPTION._serialized_end = 6027
+    _CONFIGREQUEST_GETALL._serialized_start = 6029
+    _CONFIGREQUEST_GETALL._serialized_end = 6077
+    _CONFIGREQUEST_UNSET._serialized_start = 6079
+    _CONFIGREQUEST_UNSET._serialized_end = 6106
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 6108
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 6142
+    _CONFIGRESPONSE._serialized_start = 6160
+    _CONFIGRESPONSE._serialized_end = 6282
+    _ADDARTIFACTSREQUEST._serialized_start = 6285
+    _ADDARTIFACTSREQUEST._serialized_end = 7156
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 6672
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6725
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6727
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6838
+    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6840
+    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6933
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6936
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 7129
+    _ADDARTIFACTSRESPONSE._serialized_start = 7159
+    _ADDARTIFACTSRESPONSE._serialized_end = 7347
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 7266
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 7347
+    _SPARKCONNECTSERVICE._serialized_start = 7350
+    _SPARKCONNECTSERVICE._serialized_end = 7715
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 5ff14c9ac93..4c020308d9a 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1056,6 +1056,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
     EXTENSION_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
     OBSERVED_METRICS_FIELD_NUMBER: builtins.int
+    SCHEMA_FIELD_NUMBER: builtins.int
     session_id: builtins.str
     @property
     def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
@@ -1077,6 +1078,9 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         global___ExecutePlanResponse.ObservedMetrics
     ]:
         """The metrics observed during the execution of the query plan."""
+    @property
+    def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+        """(Optional) The Spark schema. This field is available when `collect` is called."""
     def __init__(
         self,
         *,
@@ -1087,6 +1091,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         metrics: global___ExecutePlanResponse.Metrics | None = ...,
         observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics]
         | None = ...,
+        schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -1099,6 +1104,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             b"metrics",
             "response_type",
             b"response_type",
+            "schema",
+            b"schema",
             "sql_command_result",
             b"sql_command_result",
         ],
@@ -1116,6 +1123,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             b"observed_metrics",
             "response_type",
             b"response_type",
+            "schema",
+            b"schema",
             "session_id",
             b"session_id",
             "sql_command_result",
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4cf4d0341bb..8fe5020f4a4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -342,20 +342,23 @@ class SparkSession:
                     # For cases like createDataFrame([("Alice", None, 80.1)], schema)
                     # we can not infer the schema from the data itself.
                     warnings.warn("failed to infer the schema from data")
-                    if _schema is None and _schema_str is not None:
+                    if _schema_str is not None:
                         _parsed = self.client._analyze(
                             method="ddl_parse", ddl_string=_schema_str
                         ).parsed
                         if isinstance(_parsed, StructType):
-                            _schema = _parsed
+                            _inferred_schema = _parsed
                         elif isinstance(_parsed, DataType):
-                            _schema = StructType().add("value", _parsed)
-                    if _schema is None or not isinstance(_schema, StructType):
+                            _inferred_schema = StructType().add("value", _parsed)
+                        _schema_str = None
+                    if _inferred_schema is None or not isinstance(_inferred_schema, StructType):
                         raise ValueError(
                             "Some of types cannot be determined after 
inferring, "
                             "a StructType Schema is required in this case"
                         )
-                    _inferred_schema = _schema
+
+                if _schema_str is None and _cols is None:
+                    _schema = _inferred_schema
 
             from pyspark.sql.connect.conversion import LocalDataToArrowConversion
 
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index b5145d91c76..dfb0fb5303f 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -323,6 +323,8 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
         arrow_type = pa.struct(fields)
     elif type(dt) == NullType:
         arrow_type = pa.null()
+    elif isinstance(dt, UserDefinedType):
+        arrow_type = to_arrow_type(dt.sqlType())
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 67d5a17660e..a2f81fbf25e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -34,26 +34,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
     def test_apply_schema_to_row(self):
         super().test_apply_schema_to_dict_and_rows()
 
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_apply_schema_with_udt(self):
-        super().test_apply_schema_with_udt()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_cast_to_string_with_udt(self):
-        super().test_cast_to_string_with_udt()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_cast_to_udt_with_udt(self):
-        super().test_cast_to_udt_with_udt()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_complex_nested_udt_in_df(self):
-        super().test_complex_nested_udt_in_df()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_create_dataframe_schema_mismatch(self):
         super().test_create_dataframe_schema_mismatch()
@@ -103,46 +83,14 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
     def test_infer_schema_upcast_int_to_string(self):
         super().test_infer_schema_upcast_int_to_string()
 
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_infer_schema_with_udt(self):
-        super().test_infer_schema_with_udt()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_nested_udt_in_df(self):
-        super().test_nested_udt_in_df()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_parquet_with_udt(self):
-        super().test_parquet_with_udt()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_simple_udt_in_df(self):
-        super().test_simple_udt_in_df()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_udf_with_udt(self):
         super().test_udf_with_udt()
 
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
+    @unittest.skip("Requires JVM access.")
     def test_udt(self):
         super().test_udt()
 
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udt_with_none(self):
-        super().test_udt_with_none()
-
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_union_with_udt(self):
-        super().test_union_with_udt()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 9db090fa810..aaac43cdf67 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -692,10 +692,10 @@ class TypesTestsMixin:
 
         row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
-        self.assertRaises(AnalysisException, lambda: df.select(col("point").cast(PythonOnlyUDT())))
-        self.assertRaises(
-            AnalysisException, lambda: df.select(col("python_only_point").cast(ExamplePointUDT()))
-        )
+        with self.assertRaises(AnalysisException):
+            df.select(col("point").cast(PythonOnlyUDT())).collect()
+        with self.assertRaises(AnalysisException):
+            df.select(col("python_only_point").cast(ExamplePointUDT())).collect()
 
     def test_struct_type(self):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6c6635bac57..d6a8fec81dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -106,6 +106,7 @@ private[sql] object ArrowUtils {
               .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
             nullable = false,
             timeZoneId)).asJava)
+      case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId)
       case dataType =>
         val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
         new Field(name, fieldType, Seq.empty[Field].asJava)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

