This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fbd620989d7 [SPARK-42042][CONNECT][PYTHON] `DataFrameReader` should support StructType schema
fbd620989d7 is described below
commit fbd620989d77e810c1066889840b237a7eca920a
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Jan 13 14:56:09 2023 +0800
[SPARK-42042][CONNECT][PYTHON] `DataFrameReader` should support StructType schema
### What changes were proposed in this pull request?
`DataFrameReader` in the Spark Connect Python client should accept a `StructType` schema in addition to a schema string.
### Why are the changes needed?
For feature parity with the vanilla PySpark `DataFrameReader`, whose `schema` method already accepts both a `StructType` and a DDL string.
### Does this PR introduce _any_ user-facing change?
Yes, `DataFrameReader.schema` (and the `schema` parameter of `load` and `json`) now accepts a `StructType` as well as a string.
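For example (a minimal sketch; assumes a Spark Connect session bound to `spark` and JSON data at a hypothetical path `/tmp/people`):

```python
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

schema = StructType(
    [
        StructField("age", IntegerType()),
        StructField("name", StringType()),
    ]
)

# Both calls now build the same read plan; previously the Connect client
# only accepted the DDL string form.
df1 = spark.read.json("/tmp/people", schema="age INT, name STRING")
df2 = spark.read.json("/tmp/people", schema=schema)
```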
### How was this patch tested?
Updated unit tests.
Closes #39545 from zhengruifeng/connect_io_struct_type_schema.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 2 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 11 ++++--
python/pyspark/sql/connect/proto/relations_pb2.pyi | 5 ++-
python/pyspark/sql/connect/readwriter.py | 14 +++++---
.../sql/tests/connect/test_connect_basic.py | 39 +++++++++++++++-------
5 files changed, 51 insertions(+), 20 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 1ddf9f9c0db..f029273c48b 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -119,6 +119,8 @@ message Read {
     string format = 1;
 
     // (Optional) If not set, Spark will infer the schema.
+    //
+    // This schema string should be either DDL-formatted or JSON-formatted.
     optional string schema = 2;
 
     // Options for the data source. The context of this map varies based on the
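For illustration, the two string forms this field can carry (a sketch using public PySpark APIs; the field names are only examples):

```python
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

st = StructType([StructField("age", IntegerType()), StructField("name", StringType())])

ddl_schema = "age INT, name STRING"  # DDL-formatted
json_schema = st.json()              # JSON-formatted, as the Python client now sends for a StructType
```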
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 4c6c63729d5..512ac8efe2e 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -658,8 +658,15 @@ class SparkConnectPlanner(session: SparkSession) {
         val reader = session.read
         reader.format(rel.getDataSource.getFormat)
         localMap.foreach { case (key, value) => reader.option(key, value) }
-        if (rel.getDataSource.getSchema != null && !rel.getDataSource.getSchema.isEmpty) {
-          reader.schema(rel.getDataSource.getSchema)
+        if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {
+
+          DataType.parseTypeWithFallback(
+            rel.getDataSource.getSchema,
+            StructType.fromDDL,
+            fallbackParser = DataType.fromJson) match {
+            case s: StructType => reader.schema(s)
+            case other => throw InvalidPlanInput(s"Invalid schema $other")
+          }
         }
         reader.load().queryExecution.analyzed
       case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
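The server tries the DDL parser first and falls back to the JSON parser. The JSON round trip that the fallback relies on can be sanity-checked from Python with public APIs (a minimal sketch; the field names are only examples):

```python
import json

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

st = StructType([StructField("age", IntegerType()), StructField("name", StringType())])

# A client-produced st.json() string parses back to an equal StructType,
# mirroring the server-side DataType.fromJson fallback path.
assert StructType.fromJson(json.loads(st.json())) == st
```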
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 282bac5dff1..04512a4c891 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -557,7 +557,10 @@ class Read(google.protobuf.message.Message):
         format: builtins.str
         """(Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro."""
         schema: builtins.str
-        """(Optional) If not set, Spark will infer the schema."""
+        """(Optional) If not set, Spark will infer the schema.
+
+        This schema string should be either DDL-formatted or JSON-formatted.
+        """
         @property
         def options(
             self,
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index f256e5d3118..62a082cc90a 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -69,9 +69,13 @@ class DataFrameReader(OptionUtils):
 
     format.__doc__ = PySparkDataFrameReader.format.__doc__
 
-    # TODO(SPARK-40539): support StructType in python client and support schema as StructType.
-    def schema(self, schema: str) -> "DataFrameReader":
-        self._schema = schema
+    def schema(self, schema: Union[StructType, str]) -> "DataFrameReader":
+        if isinstance(schema, StructType):
+            self._schema = schema.json()
+        elif isinstance(schema, str):
+            self._schema = schema
+        else:
+            raise TypeError(f"schema must be a StructType or str, but got {schema}")
         return self
 
     schema.__doc__ = PySparkDataFrameReader.schema.__doc__
@@ -93,7 +97,7 @@ class DataFrameReader(OptionUtils):
         self,
         path: Optional[str] = None,
         format: Optional[str] = None,
-        schema: Optional[str] = None,
+        schema: Optional[Union[StructType, str]] = None,
         **options: "OptionalPrimitiveType",
     ) -> "DataFrame":
         if format is not None:
@@ -122,7 +126,7 @@ class DataFrameReader(OptionUtils):
     def json(
         self,
         path: str,
-        schema: Optional[str] = None,
+        schema: Optional[Union[StructType, str]] = None,
         primitivesAsString: Optional[Union[bool, str]] = None,
         prefersDecimal: Optional[Union[bool, str]] = None,
         allowComments: Optional[Union[bool, str]] = None,
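With this change, `schema()` on the Connect client accepts both forms and rejects everything else (a sketch; assumes a Spark Connect session bound to `spark`):

```python
from pyspark.sql.types import StructType, StructField, StringType

reader = spark.read.format("text")
reader.schema("id STRING")                                    # DDL string, stored as-is
reader.schema(StructType([StructField("id", StringType())]))  # StructType, serialized via .json()
# reader.schema(123)  # would raise TypeError: schema must be a StructType or str
```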
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 8d9becf259a..94b9854c7b3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -214,10 +214,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ).format("json").save(d)
             # Read the JSON file as a DataFrame.
             self.assert_eq(self.connect.read.json(d).toPandas(), self.spark.read.json(d).toPandas())
-            self.assert_eq(
-                self.connect.read.json(path=d, schema="age INT, name STRING").toPandas(),
-                self.spark.read.json(path=d, schema="age INT, name STRING").toPandas(),
-            )
+
+            for schema in [
+                "age INT, name STRING",
+                StructType(
+                    [
+                        StructField("age", IntegerType()),
+                        StructField("name", StringType()),
+                    ]
+                ),
+            ]:
+                self.assert_eq(
+                    self.connect.read.json(path=d, schema=schema).toPandas(),
+                    self.spark.read.json(path=d, schema=schema).toPandas(),
+                )
+
             self.assert_eq(
                 self.connect.read.json(path=d, primitivesAsString=True).toPandas(),
                 self.spark.read.json(path=d, primitivesAsString=True).toPandas(),
@@ -1677,14 +1688,18 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         shutil.rmtree(tmpPath)
         writeDf.write.text(tmpPath)
 
-        readDf = self.connect.read.format("text").schema("id STRING").load(path=tmpPath)
-        expectResult = writeDf.collect()
-        pandasResult = readDf.toPandas()
-        if pandasResult is None:
-            self.assertTrue(False, "Empty pandas dataframe")
-        else:
-            actualResult = pandasResult.values.tolist()
-            self.assertEqual(len(expectResult), len(actualResult))
+        for schema in [
+            "id STRING",
+            StructType([StructField("id", StringType())]),
+        ]:
+            readDf = self.connect.read.format("text").schema(schema).load(path=tmpPath)
+            expectResult = writeDf.collect()
+            pandasResult = readDf.toPandas()
+            if pandasResult is None:
+                self.assertTrue(False, "Empty pandas dataframe")
+            else:
+                actualResult = pandasResult.values.tolist()
+                self.assertEqual(len(expectResult), len(actualResult))
 
     def test_simple_read_without_schema(self) -> None:
         """SPARK-41300: Schema not set when reading CSV."""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]