This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dc6cf74dc2d [SPARK-42020][CONNECT][PYTHON] Support UserDefinedType in Spark Connect
dc6cf74dc2d is described below
commit dc6cf74dc2db54d935cf54cb3e4829a468dcdf78
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Mar 20 09:24:26 2023 +0900
[SPARK-42020][CONNECT][PYTHON] Support UserDefinedType in Spark Connect
### What changes were proposed in this pull request?
Supports `UserDefinedType` in Spark Connect.
### Why are the changes needed?
Currently Spark Connect doesn't support UDTs.
### Does this PR introduce _any_ user-facing change?
Yes, UDTs will be available in Spark Connect.
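For illustration, a minimal sketch (not part of this patch) of the round trip this enables from a Spark Connect client; the `Point`/`PointUDT` classes and the `sc://localhost:15002` connect string are assumptions made up for the example:

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import ArrayType, DoubleType, StructField, StructType, UserDefinedType


class Point:
    # Plain Python object that the UDT serializes to/from its SQL type.
    def __init__(self, x: float, y: float):
        self.x, self.y = x, y

    def __eq__(self, other):
        return isinstance(other, Point) and (self.x, self.y) == (other.x, other.y)


class PointUDT(UserDefinedType):
    # Hypothetical UDT: values travel over the wire as their SQL type, array<double>.
    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point(datum[0], datum[1])


# Assumes a Spark Connect server is reachable at this address.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
schema = StructType([StructField("point", PointUDT(), True)])
df = spark.createDataFrame([Row(point=Point(1.0, 2.0))], schema)
assert df.collect()[0].point == Point(1.0, 2.0)
```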
### How was this patch tested?
Enabled the related tests.
Closes #40402 from ueshin/issues/SPARK-42020/udt.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/protobuf/spark/connect/base.proto | 3 +
.../connect/common/DataTypeProtoConverter.scala | 1 +
.../sql/connect/planner/SparkConnectPlanner.scala | 34 +++++--
.../service/SparkConnectStreamHandler.scala | 12 +++
.../connect/planner/SparkConnectServiceSuite.scala | 18 +++-
python/pyspark/sql/connect/client.py | 27 ++++--
python/pyspark/sql/connect/conversion.py | 31 ++++++
python/pyspark/sql/connect/dataframe.py | 4 +-
python/pyspark/sql/connect/proto/base_pb2.py | 108 ++++++++++-----------
python/pyspark/sql/connect/proto/base_pb2.pyi | 9 ++
python/pyspark/sql/connect/session.py | 13 ++-
python/pyspark/sql/connect/types.py | 2 +
.../pyspark/sql/tests/connect/test_parity_types.py | 56 +----------
python/pyspark/sql/tests/test_types.py | 8 +-
.../org/apache/spark/sql/util/ArrowUtils.scala | 1 +
15 files changed, 188 insertions(+), 139 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 809a2dc5cbf..da0f974a749 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -272,6 +272,9 @@ message ExecutePlanResponse {
// The metrics observed during the execution of the query plan.
repeated ObservedMetrics observed_metrics = 6;
+ // (Optional) The Spark schema. This field is available when `collect` is called.
+ DataType schema = 7;
+
// A SQL command returns an opaque Relation that can be directly used as input for the next
// call.
message SqlCommandResult {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index c30ea8c8301..28ddbe844d4 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -335,6 +335,7 @@ object DataTypeProtoConverter {
.setType("udt")
.setPythonClass(pyudt.pyUDT)
.setSqlType(toConnectProtoType(pyudt.sqlType))
+ .setSerializedPythonClass(pyudt.serializedPyClass)
.build())
.build()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 20db252c057..b023adac98a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -676,16 +676,36 @@ class SparkConnectPlanner(val session: SparkSession) {
}
val attributes = structType.toAttributes
val proj = UnsafeProjection.create(attributes, attributes)
- val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
+ val data = rows.map(proj)
if (schema == null) {
- relation
+ logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
} else {
- Dataset
- .ofRows(session, logicalPlan = relation)
- .toDF(schema.names: _*)
- .to(schema)
+ def udtToSqlType(dt: DataType): DataType = dt match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case StructType(fields) =>
+ val newFields = fields.map { case StructField(name, dataType, nullable, metadata) =>
+ StructField(name, udtToSqlType(dataType), nullable, metadata)
+ }
+ StructType(newFields)
+ case ArrayType(elementType, containsNull) =>
+ ArrayType(udtToSqlType(elementType), containsNull)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapType(udtToSqlType(keyType), udtToSqlType(valueType), valueContainsNull)
+ case _ => dt
+ }
+
+ val sqlTypeOnlySchema = udtToSqlType(schema).asInstanceOf[StructType]
+
+ val project = Dataset
+ .ofRows(session, logicalPlan = logical.LocalRelation(attributes))
+ .toDF(sqlTypeOnlySchema.names: _*)
+ .to(sqlTypeOnlySchema)
.logicalPlan
+ .asInstanceOf[Project]
+
+ val proj = UnsafeProjection.create(project.projectList, project.child.output)
+ logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq)
}
} else {
if (schema == null) {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 104d840ed52..335b871d499 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -28,6 +28,7 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -60,6 +61,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(session)
val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
+ responseObserver.onNext(
+ SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
responseObserver.onNext(
SparkConnectStreamHandler.sendMetricsToResponse(request.getSessionId, dataframe))
@@ -203,6 +206,15 @@ object SparkConnectStreamHandler {
}
}
+ def sendSchemaToResponse(sessionId: String, schema: StructType): ExecutePlanResponse = {
+ // Send the Spark data type
+ ExecutePlanResponse
+ .newBuilder()
+ .setSessionId(sessionId)
+ .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
+ .build()
+ }
+
def sendMetricsToResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = {
// Send a last batch with the metrics
ExecutePlanResponse
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index e2aecaaea86..c36ba76f984 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -160,18 +160,22 @@ class SparkConnectServiceSuite extends SharedSparkSession {
assert(done)
// 4 Partitions + Metrics
- assert(responses.size == 5)
+ assert(responses.size == 6)
+
+ // Make sure the first response is schema only
+ val head = responses.head
+ assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
// Make sure the last response is metrics only
val last = responses.last
- assert(last.hasMetrics && !last.hasArrowBatch)
+ assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
val allocator = new RootAllocator()
// Check the 'data' batches
var expectedId = 0L
var previousEId = 0.0d
- responses.dropRight(1).foreach { response =>
+ responses.tail.dropRight(1).foreach { response =>
assert(response.hasArrowBatch)
val batch = response.getArrowBatch
assert(batch.getData != null)
@@ -347,11 +351,15 @@ class SparkConnectServiceSuite extends SharedSparkSession {
// The current implementation is expected to be blocking. This is here to make sure it is.
assert(done)
- assert(responses.size == 6)
+ assert(responses.size == 7)
+
+ // Make sure the first response is schema only
+ val head = responses.head
+ assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
// Make sure the last response is observed metrics only
val last = responses.last
- assert(last.getObservedMetricsCount == 1 && !last.hasArrowBatch)
+ assert(last.getObservedMetricsCount == 1 && !last.hasSchema && !last.hasArrowBatch)
val observedMetricsList = last.getObservedMetricsList.asScala
val observedMetric = observedMetricsList.head
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8dd80a931b9..090d239fbb4 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -619,16 +619,16 @@ class SparkConnectClient(object):
for x in metrics
]
- def to_table(self, plan: pb2.Plan) -> "pa.Table":
+ def to_table(self, plan: pb2.Plan) -> Tuple["pa.Table", Optional[StructType]]:
"""
Return given plan as a PyArrow Table.
"""
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, _, _, _ = self._execute_and_fetch(req)
+ table, schema, _, _, _ = self._execute_and_fetch(req)
assert table is not None
- return table
+ return table, schema
def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
"""
@@ -637,7 +637,7 @@ class SparkConnectClient(object):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, metrics, observed_metrics, _ = self._execute_and_fetch(req)
+ table, _, metrics, observed_metrics, _ = self._execute_and_fetch(req)
assert table is not None
column_names = table.column_names
table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
@@ -696,7 +696,7 @@ class SparkConnectClient(object):
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
- data, _, _, properties = self._execute_and_fetch(req)
+ data, _, _, _, properties = self._execute_and_fetch(req)
if data is not None:
return (data.to_pandas(), properties)
else:
@@ -844,12 +844,19 @@ class SparkConnectClient(object):
def _execute_and_fetch(
self, req: pb2.ExecutePlanRequest
- ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], List[PlanObservedMetrics], Dict[str, Any]]:
+ ) -> Tuple[
+ Optional["pa.Table"],
+ Optional[StructType],
+ List[PlanMetrics],
+ List[PlanObservedMetrics],
+ Dict[str, Any],
+ ]:
logger.info("ExecuteAndFetch")
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
om: List[pb2.ExecutePlanResponse.ObservedMetrics] = []
batches: List[pa.RecordBatch] = []
+ schema: Optional[StructType] = None
properties = {}
try:
for attempt in Retrying(
@@ -869,6 +876,10 @@ class SparkConnectClient(object):
if b.observed_metrics is not None:
logger.debug("Received observed metric batch.")
om.extend(b.observed_metrics)
+ if b.HasField("schema"):
+ dt = types.proto_schema_to_pyspark_data_type(b.schema)
+ assert isinstance(dt, StructType)
+ schema = dt
if b.HasField("sql_command_result"):
properties["sql_command_result"] =
b.sql_command_result.relation
if b.HasField("arrow_batch"):
@@ -888,9 +899,9 @@ class SparkConnectClient(object):
if len(batches) > 0:
table = pa.Table.from_batches(batches=batches)
- return table, metrics, observed_metrics, properties
+ return table, schema, metrics, observed_metrics, properties
else:
- return None, metrics, observed_metrics, properties
+ return None, schema, metrics, observed_metrics, properties
def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req = pb2.ConfigRequest()
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 2b16fc7766d..ba488d4d04e 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -37,6 +37,7 @@ from pyspark.sql.types import (
NullType,
DecimalType,
StringType,
+ UserDefinedType,
)
from pyspark.sql.connect.types import to_arrow_schema
@@ -79,6 +80,8 @@ class LocalDataToArrowConversion:
elif isinstance(dataType, StringType):
# Coercion to StringType is allowed
return True
+ elif isinstance(dataType, UserDefinedType):
+ return True
else:
return False
@@ -229,6 +232,19 @@ class LocalDataToArrowConversion:
return convert_string
+ elif isinstance(dataType, UserDefinedType):
+ udt: UserDefinedType = dataType
+
+ conv = LocalDataToArrowConversion._create_converter(dataType.sqlType())
+
+ def convert_udt(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ return conv(udt.serialize(value))
+
+ return convert_udt
+
else:
return lambda value: value
@@ -286,6 +302,8 @@ class ArrowTableToRowsConversion:
elif isinstance(dataType, (TimestampType, TimestampNTZType)):
# Always remove the time zone info for now
return True
+ elif isinstance(dataType, UserDefinedType):
+ return True
else:
return False
@@ -380,6 +398,19 @@ class ArrowTableToRowsConversion:
return convert_timestample
+ elif isinstance(dataType, UserDefinedType):
+ udt: UserDefinedType = dataType
+
+ conv = ArrowTableToRowsConversion._create_converter(dataType.sqlType())
+
+ def convert_udt(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ return udt.deserialize(conv(value))
+
+ return convert_udt
+
else:
return lambda value: value
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f1968bc0ad9..03df5781197 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1348,9 +1348,9 @@ class DataFrame:
if self._session is None:
raise Exception("Cannot collect on empty session.")
query = self._plan.to_proto(self._session.client)
- table = self._session.client.to_table(query)
+ table, schema = self._session.client.to_table(query)
- schema = from_arrow_schema(table.schema)
+ schema = schema or from_arrow_schema(table.schema)
assert schema is not None and isinstance(schema, StructType)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 5a7a10e78ca..36557344893 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01
\x01(\tR\x06 [...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01
\x01(\tR\x06 [...]
)
@@ -690,57 +690,57 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXECUTEPLANREQUEST._serialized_start = 3573
_EXECUTEPLANREQUEST._serialized_end = 3782
_EXECUTEPLANRESPONSE._serialized_start = 3785
- _EXECUTEPLANRESPONSE._serialized_end = 5011
- _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 4242
- _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 4313
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 4315
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 4376
- _EXECUTEPLANRESPONSE_METRICS._serialized_start = 4379
- _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4896
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 4474
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4806
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4683
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4806
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4808
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4896
- _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4898
- _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 4994
- _KEYVALUE._serialized_start = 5013
- _KEYVALUE._serialized_end = 5078
- _CONFIGREQUEST._serialized_start = 5081
- _CONFIGREQUEST._serialized_end = 6109
- _CONFIGREQUEST_OPERATION._serialized_start = 5301
- _CONFIGREQUEST_OPERATION._serialized_end = 5799
- _CONFIGREQUEST_SET._serialized_start = 5801
- _CONFIGREQUEST_SET._serialized_end = 5853
- _CONFIGREQUEST_GET._serialized_start = 5855
- _CONFIGREQUEST_GET._serialized_end = 5880
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5882
- _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5945
- _CONFIGREQUEST_GETOPTION._serialized_start = 5947
- _CONFIGREQUEST_GETOPTION._serialized_end = 5978
- _CONFIGREQUEST_GETALL._serialized_start = 5980
- _CONFIGREQUEST_GETALL._serialized_end = 6028
- _CONFIGREQUEST_UNSET._serialized_start = 6030
- _CONFIGREQUEST_UNSET._serialized_end = 6057
- _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 6059
- _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 6093
- _CONFIGRESPONSE._serialized_start = 6111
- _CONFIGRESPONSE._serialized_end = 6233
- _ADDARTIFACTSREQUEST._serialized_start = 6236
- _ADDARTIFACTSREQUEST._serialized_end = 7107
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 6623
- _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6676
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6678
- _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6789
- _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6791
- _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6884
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6887
- _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 7080
- _ADDARTIFACTSRESPONSE._serialized_start = 7110
- _ADDARTIFACTSRESPONSE._serialized_end = 7298
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 7217
- _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 7298
- _SPARKCONNECTSERVICE._serialized_start = 7301
- _SPARKCONNECTSERVICE._serialized_end = 7666
+ _EXECUTEPLANRESPONSE._serialized_end = 5060
+ _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 4291
+ _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 4362
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 4364
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 4425
+ _EXECUTEPLANRESPONSE_METRICS._serialized_start = 4428
+ _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4945
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 4523
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4855
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 4732
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4855
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4857
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4945
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_start = 4947
+ _EXECUTEPLANRESPONSE_OBSERVEDMETRICS._serialized_end = 5043
+ _KEYVALUE._serialized_start = 5062
+ _KEYVALUE._serialized_end = 5127
+ _CONFIGREQUEST._serialized_start = 5130
+ _CONFIGREQUEST._serialized_end = 6158
+ _CONFIGREQUEST_OPERATION._serialized_start = 5350
+ _CONFIGREQUEST_OPERATION._serialized_end = 5848
+ _CONFIGREQUEST_SET._serialized_start = 5850
+ _CONFIGREQUEST_SET._serialized_end = 5902
+ _CONFIGREQUEST_GET._serialized_start = 5904
+ _CONFIGREQUEST_GET._serialized_end = 5929
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5931
+ _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5994
+ _CONFIGREQUEST_GETOPTION._serialized_start = 5996
+ _CONFIGREQUEST_GETOPTION._serialized_end = 6027
+ _CONFIGREQUEST_GETALL._serialized_start = 6029
+ _CONFIGREQUEST_GETALL._serialized_end = 6077
+ _CONFIGREQUEST_UNSET._serialized_start = 6079
+ _CONFIGREQUEST_UNSET._serialized_end = 6106
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 6108
+ _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 6142
+ _CONFIGRESPONSE._serialized_start = 6160
+ _CONFIGRESPONSE._serialized_end = 6282
+ _ADDARTIFACTSREQUEST._serialized_start = 6285
+ _ADDARTIFACTSREQUEST._serialized_end = 7156
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 6672
+ _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 6725
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 6727
+ _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 6838
+ _ADDARTIFACTSREQUEST_BATCH._serialized_start = 6840
+ _ADDARTIFACTSREQUEST_BATCH._serialized_end = 6933
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 6936
+ _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 7129
+ _ADDARTIFACTSRESPONSE._serialized_start = 7159
+ _ADDARTIFACTSRESPONSE._serialized_end = 7347
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 7266
+ _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 7347
+ _SPARKCONNECTSERVICE._serialized_start = 7350
+ _SPARKCONNECTSERVICE._serialized_end = 7715
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 5ff14c9ac93..4c020308d9a 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1056,6 +1056,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
EXTENSION_FIELD_NUMBER: builtins.int
METRICS_FIELD_NUMBER: builtins.int
OBSERVED_METRICS_FIELD_NUMBER: builtins.int
+ SCHEMA_FIELD_NUMBER: builtins.int
session_id: builtins.str
@property
def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
@@ -1077,6 +1078,9 @@ class ExecutePlanResponse(google.protobuf.message.Message):
global___ExecutePlanResponse.ObservedMetrics
]:
"""The metrics observed during the execution of the query plan."""
+ @property
+ def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+ """(Optional) The Spark schema. This field is available when `collect`
is called."""
def __init__(
self,
*,
@@ -1087,6 +1091,7 @@ class ExecutePlanResponse(google.protobuf.message.Message):
metrics: global___ExecutePlanResponse.Metrics | None = ...,
observed_metrics: collections.abc.Iterable[global___ExecutePlanResponse.ObservedMetrics]
| None = ...,
+ schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
) -> None: ...
def HasField(
self,
@@ -1099,6 +1104,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
b"metrics",
"response_type",
b"response_type",
+ "schema",
+ b"schema",
"sql_command_result",
b"sql_command_result",
],
@@ -1116,6 +1123,8 @@ class ExecutePlanResponse(google.protobuf.message.Message):
b"observed_metrics",
"response_type",
b"response_type",
+ "schema",
+ b"schema",
"session_id",
b"session_id",
"sql_command_result",
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4cf4d0341bb..8fe5020f4a4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -342,20 +342,23 @@ class SparkSession:
# For cases like createDataFrame([("Alice", None, 80.1)], schema)
# we can not infer the schema from the data itself.
warnings.warn("failed to infer the schema from data")
- if _schema is None and _schema_str is not None:
+ if _schema_str is not None:
_parsed = self.client._analyze(
method="ddl_parse", ddl_string=_schema_str
).parsed
if isinstance(_parsed, StructType):
- _schema = _parsed
+ _inferred_schema = _parsed
elif isinstance(_parsed, DataType):
- _schema = StructType().add("value", _parsed)
- if _schema is None or not isinstance(_schema, StructType):
+ _inferred_schema = StructType().add("value", _parsed)
+ _schema_str = None
+ if _inferred_schema is None or not isinstance(_inferred_schema, StructType):
raise ValueError(
"Some of types cannot be determined after
inferring, "
"a StructType Schema is required in this case"
)
- _inferred_schema = _schema
+
+ if _schema_str is None and _cols is None:
+ _schema = _inferred_schema
from pyspark.sql.connect.conversion import LocalDataToArrowConversion
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index b5145d91c76..dfb0fb5303f 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -323,6 +323,8 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
arrow_type = pa.struct(fields)
elif type(dt) == NullType:
arrow_type = pa.null()
+ elif isinstance(dt, UserDefinedType):
+ arrow_type = to_arrow_type(dt.sqlType())
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 67d5a17660e..a2f81fbf25e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -34,26 +34,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
def test_apply_schema_to_row(self):
super().test_apply_schema_to_dict_and_rows()
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_apply_schema_with_udt(self):
- super().test_apply_schema_with_udt()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_cast_to_string_with_udt(self):
- super().test_cast_to_string_with_udt()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_cast_to_udt_with_udt(self):
- super().test_cast_to_udt_with_udt()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_complex_nested_udt_in_df(self):
- super().test_complex_nested_udt_in_df()
-
@unittest.skip("Spark Connect does not support RDD but the tests depend on
them.")
def test_create_dataframe_schema_mismatch(self):
super().test_create_dataframe_schema_mismatch()
@@ -103,46 +83,14 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
def test_infer_schema_upcast_int_to_string(self):
super().test_infer_schema_upcast_int_to_string()
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_infer_schema_with_udt(self):
- super().test_infer_schema_with_udt()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_nested_udt_in_df(self):
- super().test_nested_udt_in_df()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_parquet_with_udt(self):
- super().test_parquet_with_udt()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_simple_udt_in_df(self):
- super().test_simple_udt_in_df()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
+ @unittest.skip("Spark Connect does not support RDD but the tests depend on
them.")
def test_udf_with_udt(self):
super().test_udf_with_udt()
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
+ @unittest.skip("Requires JVM access.")
def test_udt(self):
super().test_udt()
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_udt_with_none(self):
- super().test_udt_with_none()
-
- # TODO(SPARK-42020): createDataFrame with UDT
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_union_with_udt(self):
- super().test_union_with_udt()
-
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 9db090fa810..aaac43cdf67 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -692,10 +692,10 @@ class TypesTestsMixin:
row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
- self.assertRaises(AnalysisException, lambda: df.select(col("point").cast(PythonOnlyUDT())))
- self.assertRaises(
- AnalysisException, lambda: df.select(col("python_only_point").cast(ExamplePointUDT()))
- )
+ with self.assertRaises(AnalysisException):
+ df.select(col("point").cast(PythonOnlyUDT())).collect()
+ with self.assertRaises(AnalysisException):
+ df.select(col("python_only_point").cast(ExamplePointUDT())).collect()
def test_struct_type(self):
struct1 = StructType().add("f1", StringType(), True).add("f2",
StringType(), True, None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6c6635bac57..d6a8fec81dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -106,6 +106,7 @@ private[sql] object ArrowUtils {
.add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
nullable = false,
timeZoneId)).asJava)
+ case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId)
case dataType =>
val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null)
new Field(name, fieldType, Seq.empty[Field].asJava)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]