This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d863503e8737 [SPARK-48434][PYTHON][CONNECT] Make `printSchema` use the
cached schema
d863503e8737 is described below
commit d863503e8737937fc90c68583a3762fa67f53401
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue May 28 13:37:25 2024 +0900
[SPARK-48434][PYTHON][CONNECT] Make `printSchema` use the cached schema
### What changes were proposed in this pull request?
Make `printSchema` use the cached schema
### Why are the changes needed?
To avoid extra RPCs: `printSchema` previously built its tree string via a separate request, while the schema is often already cached on the client; reusing `self.schema.treeString()` skips that round trip.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added unit tests that capture `printSchema` stdout (with and without an explicit level) and assert the Spark Connect output matches the classic PySpark output for a range of schema types.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46764 from zhengruifeng/connect_print_schema.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 5 ++++-
.../sql/tests/connect/test_connect_basic.py | 26 ++++++++++++++++++++++
2 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 62c73da374bc..354cf60c2014 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1811,7 +1811,10 @@ class DataFrame(ParentDataFrame):
return result
def printSchema(self, level: Optional[int] = None) -> None:
- print(self._tree_string(level))
+ if level:
+ print(self.schema.treeString(level))
+ else:
+ print(self.schema.treeString())
def inputFiles(self) -> List[str]:
query = self._plan.to_proto(self._session.client)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0648b5ce9925..eb5cb18d11e6 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -20,6 +20,8 @@ import gc
import unittest
import shutil
import tempfile
+import io
+from contextlib import redirect_stdout
from pyspark.util import is_remote_only
from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -352,6 +354,24 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
result = df._explain_string()
self.assertGreater(len(result), 0)
+ def _check_print_schema(self, query: str):
+ with io.StringIO() as buf, redirect_stdout(buf):
+ self.spark.sql(query).printSchema()
+ print1 = buf.getvalue()
+ with io.StringIO() as buf, redirect_stdout(buf):
+ self.connect.sql(query).printSchema()
+ print2 = buf.getvalue()
+ self.assertEqual(print1, print2, query)
+
+ for level in [-1, 0, 1, 2, 3, 4]:
+ with io.StringIO() as buf, redirect_stdout(buf):
+ self.spark.sql(query).printSchema(level)
+ print1 = buf.getvalue()
+ with io.StringIO() as buf, redirect_stdout(buf):
+ self.connect.sql(query).printSchema(level)
+ print2 = buf.getvalue()
+ self.assertEqual(print1, print2, query)
+
def test_schema(self):
schema = self.connect.read.table(self.tbl_name).schema
self.assertEqual(
@@ -373,6 +393,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
+ self._check_print_schema(query)
# test TimestampType, DateType
query = """
@@ -386,6 +407,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
+ self._check_print_schema(query)
# test DayTimeIntervalType
query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """
@@ -393,6 +415,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
+ self._check_print_schema(query)
# test MapType
query = """
@@ -406,6 +429,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
+ self._check_print_schema(query)
# test ArrayType
query = """
@@ -419,6 +443,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
+ self._check_print_schema(query)
# test StructType
query = """
@@ -432,6 +457,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.spark.sql(query).schema,
self.connect.sql(query).schema,
)
+ self._check_print_schema(query)
def test_to(self):
# SPARK-41464: test DataFrame.to()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]