This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0b3d9544c93 [SPARK-40836][CONNECT] AnalyzeResult should use struct for schema
0b3d9544c93 is described below
commit 0b3d9544c934c0c21609cd2c1a08687333c7e0ca
Author: Rui Wang <[email protected]>
AuthorDate: Tue Oct 25 15:31:17 2022 +0800
[SPARK-40836][CONNECT] AnalyzeResult should use struct for schema
### What changes were proposed in this pull request?
This PR replaces the separate column names and column types with a schema (which is a struct).
### Why are the changes needed?
Before this PR, AnalyzeResult kept column names and column types as two separate fields.
However, the two can be combined into a schema, which is a struct. This PR simplifies the
proto message accordingly.
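To make the shape of the change concrete, here is a hedged sketch (illustrative only; the method name `toProtoStructSketch` is hypothetical) of how a Catalyst `StructType` maps onto the single struct-typed `DataType` field that the analyze response now carries. The real conversion is the new `DataTypeProtoConverter.toConnectProtoStructType` in the diff below.

```scala
// Sketch only: build the struct-typed proto DataType from a Catalyst schema,
// mirroring the new DataTypeProtoConverter.toConnectProtoStructType helper.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter
import org.apache.spark.sql.types.StructType

def toProtoStructSketch(schema: StructType): proto.DataType = {
  val struct = proto.DataType.Struct.newBuilder()
  schema.fields.foreach { field =>
    struct.addFields(
      proto.DataType.StructField
        .newBuilder()
        .setName(field.name)
        // delegate per-field types to the existing converter
        .setType(DataTypeProtoConverter.toConnectProtoType(field.dataType))
        .setNullable(field.nullable))
  }
  proto.DataType.newBuilder().setStruct(struct).build()
}
```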
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
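For reference, a condensed, hedged sketch of what the new unit test asserts (the full version is `SparkConnectServiceSuite` in the diff below); `spark` is assumed to be a test SparkSession in which a table `test(col1 INT, col2 STRING)` already exists.

```scala
// Condensed sketch of the new test's core check; `spark` is an assumed test
// SparkSession with table `test(col1 INT, col2 STRING)` created beforehand.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.service.SparkConnectService

val instance = new SparkConnectService(false)
val relation = proto.Relation
  .newBuilder()
  .setRead(
    proto.Read
      .newBuilder()
      .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build())
      .build())
  .build()

val response = instance.handleAnalyzePlanRequest(relation, spark)
// The analyze response now carries the schema as a single struct-typed DataType.
assert(response.getSchema.hasStruct)
assert(response.getSchema.getStruct.getFieldsCount == 2)
```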
Closes #38301 from amaliujia/return_schema_use_struct.
Authored-by: Rui Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/protobuf/spark/connect/base.proto | 6 +--
.../connect/planner/DataTypeProtoConverter.scala | 19 ++++++-
.../sql/connect/service/SparkConnectService.scala | 47 +++++++++---------
.../connect/planner/SparkConnectServiceSuite.scala | 58 ++++++++++++++++++++++
python/pyspark/sql/connect/client.py | 47 ++++++++++++++++--
python/pyspark/sql/connect/dataframe.py | 26 +++++++++-
python/pyspark/sql/connect/proto/base_pb2.py | 51 +++++++++----------
python/pyspark/sql/connect/proto/base_pb2.pyi | 27 +++-------
.../sql/tests/connect/test_connect_basic.py | 10 ++++
9 files changed, 212 insertions(+), 79 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index dff1734335e..b376515bf1a 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -22,6 +22,7 @@ package spark.connect;
import "google/protobuf/any.proto";
import "spark/connect/commands.proto";
import "spark/connect/relations.proto";
+import "spark/connect/types.proto";
option java_multiple_files = true;
option java_package = "org.apache.spark.connect.proto";
@@ -116,11 +117,10 @@ message Response {
// reason about the performance.
message AnalyzeResponse {
string client_id = 1;
- repeated string column_names = 2;
- repeated string column_types = 3;
+ DataType schema = 2;
// The extended explain string as produced by Spark.
- string explain_string = 4;
+ string explain_string = 3;
}
// Main interface for the SparkConnect service.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index da3adce43ba..0ee90b5e8fb 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._
import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
/**
 * This object offers methods to convert to/from connect proto to catalyst types.
@@ -50,11 +50,28 @@ object DataTypeProtoConverter {
proto.DataType.newBuilder().setI32(proto.DataType.I32.getDefaultInstance).build()
case StringType =>
proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build()
+ case LongType =>
+ proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
+ case struct: StructType =>
+ toConnectProtoStructType(struct)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
}
+ def toConnectProtoStructType(schema: StructType): proto.DataType = {
+ val struct = proto.DataType.Struct.newBuilder()
+ for (structField <- schema.fields) {
+ struct.addFields(
+ proto.DataType.StructField
+ .newBuilder()
+ .setName(structField.name)
+ .setType(toConnectProtoType(structField.dataType))
+ .setNullable(structField.nullable))
+ }
+ proto.DataType.newBuilder().setStruct(struct).build()
+ }
+
def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
mode match {
case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 20776a29eda..5841017e5bb 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.connect.service
import java.util.concurrent.TimeUnit
-import scala.collection.JavaConverters._
-
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
import io.grpc.{Server, Status}
@@ -35,7 +33,7 @@ import org.apache.spark.connect.proto.{AnalyzeResponse, Request, Response, Spark
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT
-import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
import org.apache.spark.sql.execution.ExtendedMode
/**
@@ -89,29 +87,16 @@ class SparkConnectService(debug: Boolean)
request: Request,
responseObserver: StreamObserver[AnalyzeResponse]): Unit = {
try {
+ if (request.getPlan.getOpTypeCase != proto.Plan.OpTypeCase.ROOT) {
+ responseObserver.onError(
+ new UnsupportedOperationException(
+ s"${request.getPlan.getOpTypeCase} not supported for analysis."))
+ }
val session =
SparkConnectService.getOrCreateIsolatedSession(request.getUserContext.getUserId).session
-
- val logicalPlan = request.getPlan.getOpTypeCase match {
- case proto.Plan.OpTypeCase.ROOT =>
- new SparkConnectPlanner(request.getPlan.getRoot, session).transform()
- case _ =>
- responseObserver.onError(
- new UnsupportedOperationException(
- s"${request.getPlan.getOpTypeCase} not supported for analysis."))
- return
- }
- val ds = Dataset.ofRows(session, logicalPlan)
- val explainString = ds.queryExecution.explainString(ExtendedMode)
-
- val resp = proto.AnalyzeResponse
- .newBuilder()
- .setExplainString(explainString)
- .setClientId(request.getClientId)
-
- resp.addAllColumnTypes(ds.schema.fields.map(_.dataType.sql).toSeq.asJava)
- resp.addAllColumnNames(ds.schema.fields.map(_.name).toSeq.asJava)
- responseObserver.onNext(resp.build())
+ val response = handleAnalyzePlanRequest(request.getPlan.getRoot, session)
+ response.setClientId(request.getClientId)
+ responseObserver.onNext(response.build())
responseObserver.onCompleted()
} catch {
case e: Throwable =>
@@ -120,6 +105,20 @@ class SparkConnectService(debug: Boolean)
Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException())
}
}
+
+ def handleAnalyzePlanRequest(
+ relation: proto.Relation,
+ session: SparkSession): proto.AnalyzeResponse.Builder = {
+ val logicalPlan = new SparkConnectPlanner(relation, session).transform()
+
+ val ds = Dataset.ofRows(session, logicalPlan)
+ val explainString = ds.queryExecution.explainString(ExtendedMode)
+
+ val response = proto.AnalyzeResponse
+ .newBuilder()
+ .setExplainString(explainString)
+ response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema))
+ }
}
/**
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
new file mode 100644
index 00000000000..4be8d1705b9
--- /dev/null
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.planner
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * Testing Connect Service implementation.
+ */
+class SparkConnectServiceSuite extends SharedSparkSession {
+
+ test("Test schema in analyze response") {
+ withTable("test") {
+ spark.sql("""
+ | CREATE TABLE test (col1 INT, col2 STRING)
+ | USING parquet
+ |""".stripMargin)
+
+ val instance = new SparkConnectService(false)
+ val relation = proto.Relation
+ .newBuilder()
+ .setRead(
+ proto.Read
+ .newBuilder()
+ .setNamedTable(proto.Read.NamedTable.newBuilder.setUnparsedIdentifier("test").build())
+ .build())
+ .build()
+
+ val response = instance.handleAnalyzePlanRequest(relation, spark)
+
+ assert(response.getSchema.hasStruct)
+ val schema = response.getSchema.getStruct
+ assert(schema.getFieldsCount == 2)
+ assert(
+ schema.getFields(0).getName == "col1"
+ && schema.getFields(0).getType.getKindCase == proto.DataType.KindCase.I32)
+ assert(
+ schema.getFields(1).getName == "col2"
+ && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING)
+ }
+ }
+}
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 0ae075521c6..f4b6d2ec302 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -33,6 +33,7 @@ from pyspark import cloudpickle
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.plan import SQL
+from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType
from typing import Optional, Any, Union
@@ -91,14 +92,13 @@ class PlanMetrics:
class AnalyzeResult:
- def __init__(self, cols: typing.List[str], types: typing.List[str], explain: str):
- self.cols = cols
- self.types = types
+ def __init__(self, schema: pb2.DataType, explain: str):
+ self.schema = schema
self.explain_string = explain
@classmethod
def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
- return AnalyzeResult(pb.column_names, pb.column_types, pb.explain_string)
+ return AnalyzeResult(pb.schema, pb.explain_string)
class RemoteSparkSession(object):
@@ -151,7 +151,44 @@ class RemoteSparkSession(object):
req.plan.CopyFrom(plan)
return self._execute_and_fetch(req)
- def analyze(self, plan: pb2.Plan) -> AnalyzeResult:
+ def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
+ if schema.HasField("struct"):
+ structFields = []
+ for proto_field in schema.struct.fields:
+ structFields.append(
+ StructField(
+ proto_field.name,
+ self._proto_schema_to_pyspark_schema(proto_field.type),
+ proto_field.nullable,
+ )
+ )
+ return StructType(structFields)
+ elif schema.HasField("i64"):
+ return LongType()
+ elif schema.HasField("string"):
+ return StringType()
+ else:
+ raise Exception("Only support long, string, struct conversion")
+
+ def schema(self, plan: pb2.Plan) -> StructType:
+ proto_schema = self._analyze(plan).schema
+ # Server side should populate the struct field which is the schema.
+ assert proto_schema.HasField("struct")
+ structFields = []
+ for proto_field in proto_schema.struct.fields:
+ structFields.append(
+ StructField(
+ proto_field.name,
+ self._proto_schema_to_pyspark_schema(proto_field.type),
+ proto_field.nullable,
+ )
+ )
+ return StructType(structFields)
+
+ def explain_string(self, plan: pb2.Plan) -> str:
+ return self._analyze(plan).explain_string
+
+ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
req = pb2.Request()
req.user_context.user_id = self._user_id
req.plan.CopyFrom(plan)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2b7e3d52039..bf9ed83615b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -34,6 +34,7 @@ from pyspark.sql.connect.column import (
Expression,
LiteralExpression,
)
+from pyspark.sql.types import StructType
if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
@@ -96,7 +97,7 @@ class DataFrame(object):
of the DataFrame with the changes applied.
"""
- def __init__(self, data: Optional[List[Any]] = None, schema: Optional[List[str]] = None):
+ def __init__(self, data: Optional[List[Any]] = None, schema: Optional[StructType] = None):
"""Creates a new data frame"""
self._schema = schema
self._plan: Optional[plan.LogicalPlan] = None
@@ -315,11 +316,32 @@ class DataFrame(object):
query = self._plan.to_proto(self._session)
return self._session._to_pandas(query)
+ def schema(self) -> StructType:
+ """Returns the schema of this :class:`DataFrame` as a
:class:`pyspark.sql.types.StructType`.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ :class:`StructType`
+ """
+ if self._schema is None:
+ if self._plan is not None:
+ query = self._plan.to_proto(self._session)
+ if self._session is None:
+ raise Exception("Cannot analyze without
RemoteSparkSession.")
+ self._schema = self._session.schema(query)
+ return self._schema
+ else:
+ raise Exception("Empty plan.")
+ else:
+ return self._schema
+
def explain(self) -> str:
if self._plan is not None:
query = self._plan.to_proto(self._session)
if self._session is None:
raise Exception("Cannot analyze without RemoteSparkSession.")
- return self._session.analyze(query).explain_string
+ return self._session.explain_string(query)
else:
return ""
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 408872dbb66..eb9ecc9157f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -31,10 +31,11 @@ _sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2
from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
+from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.co [...]
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_contex [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,28 +46,28 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
- _PLAN._serialized_start = 131
- _PLAN._serialized_end = 247
- _REQUEST._serialized_start = 250
- _REQUEST._serialized_end = 524
- _REQUEST_USERCONTEXT._serialized_start = 402
- _REQUEST_USERCONTEXT._serialized_end = 524
- _RESPONSE._serialized_start = 527
- _RESPONSE._serialized_end = 1495
- _RESPONSE_ARROWBATCH._serialized_start = 756
- _RESPONSE_ARROWBATCH._serialized_end = 931
- _RESPONSE_JSONBATCH._serialized_start = 933
- _RESPONSE_JSONBATCH._serialized_end = 993
- _RESPONSE_METRICS._serialized_start = 996
- _RESPONSE_METRICS._serialized_end = 1480
- _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080
- _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390
- _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278
- _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390
- _RESPONSE_METRICS_METRICVALUE._serialized_start = 1392
- _RESPONSE_METRICS_METRICVALUE._serialized_end = 1480
- _ANALYZERESPONSE._serialized_start = 1498
- _ANALYZERESPONSE._serialized_end = 1653
- _SPARKCONNECTSERVICE._serialized_start = 1656
- _SPARKCONNECTSERVICE._serialized_end = 1818
+ _PLAN._serialized_start = 158
+ _PLAN._serialized_end = 274
+ _REQUEST._serialized_start = 277
+ _REQUEST._serialized_end = 551
+ _REQUEST_USERCONTEXT._serialized_start = 429
+ _REQUEST_USERCONTEXT._serialized_end = 551
+ _RESPONSE._serialized_start = 554
+ _RESPONSE._serialized_end = 1522
+ _RESPONSE_ARROWBATCH._serialized_start = 783
+ _RESPONSE_ARROWBATCH._serialized_end = 958
+ _RESPONSE_JSONBATCH._serialized_start = 960
+ _RESPONSE_JSONBATCH._serialized_end = 1020
+ _RESPONSE_METRICS._serialized_start = 1023
+ _RESPONSE_METRICS._serialized_end = 1507
+ _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1107
+ _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1417
+ _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1305
+ _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1417
+ _RESPONSE_METRICS_METRICVALUE._serialized_start = 1419
+ _RESPONSE_METRICS_METRICVALUE._serialized_end = 1507
+ _ANALYZERESPONSE._serialized_start = 1525
+ _ANALYZERESPONSE._serialized_end = 1659
+ _SPARKCONNECTSERVICE._serialized_start = 1662
+ _SPARKCONNECTSERVICE._serialized_end = 1824
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index bb3a6578cf7..5ffd7701b44 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -41,6 +41,7 @@ import google.protobuf.internal.containers
import google.protobuf.message
import pyspark.sql.connect.proto.commands_pb2
import pyspark.sql.connect.proto.relations_pb2
+import pyspark.sql.connect.proto.types_pb2
import sys
if sys.version_info >= (3, 8):
@@ -401,39 +402,27 @@ class AnalyzeResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
CLIENT_ID_FIELD_NUMBER: builtins.int
- COLUMN_NAMES_FIELD_NUMBER: builtins.int
- COLUMN_TYPES_FIELD_NUMBER: builtins.int
+ SCHEMA_FIELD_NUMBER: builtins.int
EXPLAIN_STRING_FIELD_NUMBER: builtins.int
client_id: builtins.str
@property
- def column_names(
- self,
- ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
- @property
- def column_types(
- self,
- ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+ def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
explain_string: builtins.str
"""The extended explain string as produced by Spark."""
def __init__(
self,
*,
client_id: builtins.str = ...,
- column_names: collections.abc.Iterable[builtins.str] | None = ...,
- column_types: collections.abc.Iterable[builtins.str] | None = ...,
+ schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
explain_string: builtins.str = ...,
) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["schema", b"schema"]
+ ) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "client_id",
- b"client_id",
- "column_names",
- b"column_names",
- "column_types",
- b"column_types",
- "explain_string",
- b"explain_string",
+ "client_id", b"client_id", "explain_string", b"explain_string",
"schema", b"schema"
],
) -> None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f6988a1d120..459b05cc37a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -22,6 +22,7 @@ import tempfile
import pandas
from pyspark.sql import SparkSession, Row
+from pyspark.sql.types import StructType, StructField, LongType, StringType
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.connect.function_builder import udf
from pyspark.sql.connect.functions import lit
@@ -97,6 +98,15 @@ class SparkConnectTests(SparkConnectSQLTestCase):
result = df.explain()
self.assertGreater(len(result), 0)
+ def test_schema(self):
+ schema = self.connect.read.table(self.tbl_name).schema()
+ self.assertEqual(
+ StructType(
+ [StructField("id", LongType(), True), StructField("name",
StringType(), True)]
+ ),
+ schema,
+ )
+
def test_simple_binary_expressions(self):
"""Test complex expression"""
df = self.connect.read.table(self.tbl_name)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]