This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d863503e8737 [SPARK-48434][PYTHON][CONNECT] Make `printSchema` use the cached schema
d863503e8737 is described below

commit d863503e8737937fc90c68583a3762fa67f53401
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue May 28 13:37:25 2024 +0900

    [SPARK-48434][PYTHON][CONNECT] Make `printSchema` use the cached schema
    
    ### What changes were proposed in this pull request?
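    Make the Spark Connect `DataFrame.printSchema` build its output from the cached `schema` (via `StructType.treeString`) instead of calling `_tree_string`.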
    
    ### Why are the changes needed?
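    To avoid extra RPCs: the schema is already cached on the client, so printing it should not require another round trip to the server.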
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
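    Added tests that compare the `printSchema` output of classic Spark and Spark Connect, both without a `level` argument and for levels -1 through 4.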
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46764 from zhengruifeng/connect_print_schema.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py            |  5 ++++-
 .../sql/tests/connect/test_connect_basic.py        | 26 ++++++++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 62c73da374bc..354cf60c2014 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1811,7 +1811,10 @@ class DataFrame(ParentDataFrame):
         return result
 
     def printSchema(self, level: Optional[int] = None) -> None:
-        print(self._tree_string(level))
+        if level:
+            print(self.schema.treeString(level))
+        else:
+            print(self.schema.treeString())
 
     def inputFiles(self) -> List[str]:
         query = self._plan.to_proto(self._session.client)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0648b5ce9925..eb5cb18d11e6 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -20,6 +20,8 @@ import gc
 import unittest
 import shutil
 import tempfile
+import io
+from contextlib import redirect_stdout
 
 from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
@@ -352,6 +354,24 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         result = df._explain_string()
         self.assertGreater(len(result), 0)
 
+    def _check_print_schema(self, query: str):
+        with io.StringIO() as buf, redirect_stdout(buf):
+            self.spark.sql(query).printSchema()
+            print1 = buf.getvalue()
+        with io.StringIO() as buf, redirect_stdout(buf):
+            self.connect.sql(query).printSchema()
+            print2 = buf.getvalue()
+        self.assertEqual(print1, print2, query)
+
+        for level in [-1, 0, 1, 2, 3, 4]:
+            with io.StringIO() as buf, redirect_stdout(buf):
+                self.spark.sql(query).printSchema(level)
+                print1 = buf.getvalue()
+            with io.StringIO() as buf, redirect_stdout(buf):
+                self.connect.sql(query).printSchema(level)
+                print2 = buf.getvalue()
+            self.assertEqual(print1, print2, query)
+
     def test_schema(self):
         schema = self.connect.read.table(self.tbl_name).schema
         self.assertEqual(
@@ -373,6 +393,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.spark.sql(query).schema,
             self.connect.sql(query).schema,
         )
+        self._check_print_schema(query)
 
         # test TimestampType, DateType
         query = """
@@ -386,6 +407,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.spark.sql(query).schema,
             self.connect.sql(query).schema,
         )
+        self._check_print_schema(query)
 
         # test DayTimeIntervalType
         query = """ SELECT INTERVAL '100 10:30' DAY TO MINUTE AS interval """
@@ -393,6 +415,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.spark.sql(query).schema,
             self.connect.sql(query).schema,
         )
+        self._check_print_schema(query)
 
         # test MapType
         query = """
@@ -406,6 +429,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.spark.sql(query).schema,
             self.connect.sql(query).schema,
         )
+        self._check_print_schema(query)
 
         # test ArrayType
         query = """
@@ -419,6 +443,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.spark.sql(query).schema,
             self.connect.sql(query).schema,
         )
+        self._check_print_schema(query)
 
         # test StructType
         query = """
@@ -432,6 +457,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.spark.sql(query).schema,
             self.connect.sql(query).schema,
         )
+        self._check_print_schema(query)
 
     def test_to(self):
         # SPARK-41464: test DataFrame.to()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to