This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0e1401dc71b [SPARK-43894][PYTHON] Fix bug in df.cache()
0e1401dc71b is described below

commit 0e1401dc71b5aee540a54fc6a36a1857b13390b4
Author: Martin Grund <[email protected]>
AuthorDate: Wed May 31 11:55:19 2023 -0400

    [SPARK-43894][PYTHON] Fix bug in df.cache()
    
    ### What changes were proposed in this pull request?
    Previously, calling `df.cache()` resulted in an invalid plan input exception because `persist()` was not invoked with the right arguments. This patch simplifies the logic by having `cache()` call `persist()` directly, which makes it compatible with the behavior in Spark itself.
    
    ### Why are the changes needed?
    Bug fix: `df.cache()` raised an invalid plan input exception.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a unit test.
    
    Closes #41399 from grundprinzip/df_cache.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
    (cherry picked from commit d3f76c6ca07a7a11fd228dde770186c0fbc3f03f)
    Signed-off-by: Herman van Hovell <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py                | 4 +---
 python/pyspark/sql/tests/connect/test_connect_basic.py | 6 ++++++
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index ca2e1b7a0dc..03049109061 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1544,9 +1544,7 @@ class DataFrame:
     def cache(self) -> "DataFrame":
         if self._plan is None:
             raise Exception("Cannot cache on empty plan.")
-        relation = self._plan.plan(self._session.client)
-        self._session.client._analyze(method="persist", relation=relation)
-        return self
+        return self.persist()
 
     cache.__doc__ = PySparkDataFrame.cache.__doc__
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 008b95d6f14..b051b9233c8 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3032,6 +3032,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             message_parameters={"attr_name": "_jreader"},
         )
 
+    def test_df_cache(self):
+        df = self.connect.range(10)
+        df.cache()
+        self.assert_eq(10, df.count())
+        self.assertTrue(df.is_cached)
+
 
 class SparkConnectSessionTests(SparkConnectSQLTestCase):
     def _check_no_active_session_error(self, e: PySparkException):
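
The delegation introduced by this patch can be illustrated with a minimal, self-contained sketch. It does not use PySpark: `ToyDataFrame`, `DEFAULT_LEVEL`, and the `_storage_level` attribute are hypothetical names invented for illustration, and the storage-level string is only a placeholder. The point is that `cache()` routes through `persist()`, so both methods share one code path and one set of default arguments.

```python
from typing import Optional


class ToyDataFrame:
    """Hypothetical stand-in for a DataFrame; not the Spark Connect class."""

    # Placeholder label, assumed for illustration only.
    DEFAULT_LEVEL = "MEMORY_AND_DISK"

    def __init__(self, plan: Optional[str]) -> None:
        self._plan = plan
        self._storage_level: Optional[str] = None

    def persist(self, storage_level: str = DEFAULT_LEVEL) -> "ToyDataFrame":
        if self._plan is None:
            raise Exception("Cannot persist on empty plan.")
        # Record the requested level; a real client would send a request here.
        self._storage_level = storage_level
        return self

    def cache(self) -> "ToyDataFrame":
        if self._plan is None:
            raise Exception("Cannot cache on empty plan.")
        # Delegate to persist() so the default arguments are always supplied.
        return self.persist()

    @property
    def is_cached(self) -> bool:
        return self._storage_level is not None


df = ToyDataFrame("range(10)")
df.cache()
print(df.is_cached)
```

With this shape, any later change to the defaults of `persist()` is picked up by `cache()` automatically, which is the simplification the commit message describes.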

