This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d1d29c9840fe [SPARK-48598][PYTHON][CONNECT] Propagate cached schema in dataframe operations
d1d29c9840fe is described below

commit d1d29c9840fedecc9b5d74651526359a2b70377e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jun 12 16:48:24 2024 -0700

    [SPARK-48598][PYTHON][CONNECT] Propagate cached schema in dataframe operations
    
    ### What changes were proposed in this pull request?
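    Propagate the input DataFrame's cached schema (`_cached_schema`) to the result of
    the following operations, none of which changes the schema of its input: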
    
    - DataFrame.alias
    - DataFrame.coalesce
    - DataFrame.repartition
    - DataFrame.repartitionByRange
    - DataFrame.dropDuplicates
    - DataFrame.distinct
    - DataFrame.filter
    - DataFrame.where
    - DataFrame.limit
    - DataFrame.sort
    - DataFrame.sortWithinPartitions
    - DataFrame.orderBy
    - DataFrame.sample
    - DataFrame.hint
    - DataFrame.randomSplit
    - DataFrame.observe
    
    ### Why are the changes needed?
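    To avoid unnecessary schema-analysis RPCs: when the input DataFrame's schema has
    already been fetched from the server, a schema-preserving operation can reuse it
    instead of requesting it again.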
    
    ### Does this PR introduce _any_ user-facing change?
    No. This is an optimization only.
    
    ### How was this patch tested?
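    Added the unit test `test_cached_schema_in_chain_op` in
    `python/pyspark/sql/tests/connect/test_connect_dataframe_property.py`.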
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46954 from zhengruifeng/py_connect_propagate_schema.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py            | 69 ++++++++++++++++------
 .../connect/test_connect_dataframe_property.py     | 35 +++++++++++
 2 files changed, 85 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index baac1523c709..f2705ec7ad71 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -262,7 +262,9 @@ class DataFrame(ParentDataFrame):
             return self.groupBy().agg(*exprs)
 
     def alias(self, alias: str) -> ParentDataFrame:
-        return DataFrame(plan.SubqueryAlias(self._plan, alias), 
session=self._session)
+        res = DataFrame(plan.SubqueryAlias(self._plan, alias), 
session=self._session)
+        res._cached_schema = self._cached_schema
+        return res
 
     def colRegex(self, colName: str) -> Column:
         from pyspark.sql.connect.column import Column as ConnectColumn
@@ -314,10 +316,12 @@ class DataFrame(ParentDataFrame):
                 error_class="VALUE_NOT_POSITIVE",
                 message_parameters={"arg_name": "numPartitions", "arg_value": 
str(numPartitions)},
             )
-        return DataFrame(
+        res = DataFrame(
             plan.Repartition(self._plan, num_partitions=numPartitions, 
shuffle=False),
             self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     @overload
     def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> 
ParentDataFrame:
@@ -340,12 +344,12 @@ class DataFrame(ParentDataFrame):
                     },
                 )
             if len(cols) == 0:
-                return DataFrame(
+                res = DataFrame(
                     plan.Repartition(self._plan, numPartitions, shuffle=True),
                     self._session,
                 )
             else:
-                return DataFrame(
+                res = DataFrame(
                     plan.RepartitionByExpression(
                         self._plan, numPartitions, [F._to_col(c) for c in cols]
                     ),
@@ -353,7 +357,7 @@ class DataFrame(ParentDataFrame):
                 )
         elif isinstance(numPartitions, (str, Column)):
             cols = (numPartitions,) + cols
-            return DataFrame(
+            res = DataFrame(
                 plan.RepartitionByExpression(self._plan, None, [F._to_col(c) 
for c in cols]),
                 self.sparkSession,
             )
@@ -366,6 +370,9 @@ class DataFrame(ParentDataFrame):
                 },
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     @overload
     def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> 
ParentDataFrame:
         ...
@@ -392,14 +399,14 @@ class DataFrame(ParentDataFrame):
                     message_parameters={"item": "cols"},
                 )
             else:
-                return DataFrame(
+                res = DataFrame(
                     plan.RepartitionByExpression(
                         self._plan, numPartitions, [F._sort_col(c) for c in 
cols]
                     ),
                     self.sparkSession,
                 )
         elif isinstance(numPartitions, (str, Column)):
-            return DataFrame(
+            res = DataFrame(
                 plan.RepartitionByExpression(
                     self._plan, None, [F._sort_col(c) for c in [numPartitions] 
+ list(cols)]
                 ),
@@ -414,6 +421,9 @@ class DataFrame(ParentDataFrame):
                 },
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     def dropDuplicates(self, *subset: Union[str, List[str]]) -> 
ParentDataFrame:
         # Acceptable args should be str, ... or a single List[str]
         # So if subset length is 1, it can be either single str, or a list of 
str
@@ -422,20 +432,23 @@ class DataFrame(ParentDataFrame):
             assert all(isinstance(c, str) for c in subset)
 
         if not subset:
-            return DataFrame(
+            res = DataFrame(
                 plan.Deduplicate(child=self._plan, all_columns_as_keys=True), 
session=self._session
             )
         elif len(subset) == 1 and isinstance(subset[0], list):
-            return DataFrame(
+            res = DataFrame(
                 plan.Deduplicate(child=self._plan, column_names=subset[0]),
                 session=self._session,
             )
         else:
-            return DataFrame(
+            res = DataFrame(
                 plan.Deduplicate(child=self._plan, 
column_names=cast(List[str], subset)),
                 session=self._session,
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     drop_duplicates = dropDuplicates
 
     def dropDuplicatesWithinWatermark(self, *subset: Union[str, List[str]]) -> 
ParentDataFrame:
@@ -466,9 +479,11 @@ class DataFrame(ParentDataFrame):
             )
 
     def distinct(self) -> ParentDataFrame:
-        return DataFrame(
+        res = DataFrame(
             plan.Deduplicate(child=self._plan, all_columns_as_keys=True), 
session=self._session
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     @overload
     def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
@@ -499,7 +514,9 @@ class DataFrame(ParentDataFrame):
             expr = F.expr(condition)
         else:
             expr = condition
-        return DataFrame(plan.Filter(child=self._plan, filter=expr), 
session=self._session)
+        res = DataFrame(plan.Filter(child=self._plan, filter=expr), 
session=self._session)
+        res._cached_schema = self._cached_schema
+        return res
 
     def first(self) -> Optional[Row]:
         return self.head()
@@ -709,7 +726,9 @@ class DataFrame(ParentDataFrame):
         )
 
     def limit(self, n: int) -> ParentDataFrame:
-        return DataFrame(plan.Limit(child=self._plan, limit=n), 
session=self._session)
+        res = DataFrame(plan.Limit(child=self._plan, limit=n), 
session=self._session)
+        res._cached_schema = self._cached_schema
+        return res
 
     def tail(self, num: int) -> List[Row]:
         return DataFrame(plan.Tail(child=self._plan, limit=num), 
session=self._session).collect()
@@ -766,7 +785,7 @@ class DataFrame(ParentDataFrame):
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        return DataFrame(
+        res = DataFrame(
             plan.Sort(
                 self._plan,
                 columns=self._sort_cols(cols, kwargs),
@@ -774,6 +793,8 @@ class DataFrame(ParentDataFrame):
             ),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     orderBy = sort
 
@@ -782,7 +803,7 @@ class DataFrame(ParentDataFrame):
         *cols: Union[int, str, Column, List[Union[int, str, Column]]],
         **kwargs: Any,
     ) -> ParentDataFrame:
-        return DataFrame(
+        res = DataFrame(
             plan.Sort(
                 self._plan,
                 columns=self._sort_cols(cols, kwargs),
@@ -790,6 +811,8 @@ class DataFrame(ParentDataFrame):
             ),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     def sample(
         self,
@@ -837,7 +860,7 @@ class DataFrame(ParentDataFrame):
 
         seed = int(seed) if seed is not None else random.randint(0, 
sys.maxsize)
 
-        return DataFrame(
+        res = DataFrame(
             plan.Sample(
                 child=self._plan,
                 lower_bound=0.0,
@@ -847,6 +870,8 @@ class DataFrame(ParentDataFrame):
             ),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
         return self.withColumnsRenamed({existing: new})
@@ -1050,10 +1075,12 @@ class DataFrame(ParentDataFrame):
                         },
                     )
 
-        return DataFrame(
+        res = DataFrame(
             plan.Hint(self._plan, name, [F.lit(p) for p in list(parameters)]),
             session=self._session,
         )
+        res._cached_schema = self._cached_schema
+        return res
 
     def randomSplit(
         self,
@@ -1094,6 +1121,7 @@ class DataFrame(ParentDataFrame):
                 ),
                 session=self._session,
             )
+            samplePlan._cached_schema = self._cached_schema
             splits.append(samplePlan)
             j += 1
 
@@ -1118,9 +1146,9 @@ class DataFrame(ParentDataFrame):
             )
 
         if isinstance(observation, Observation):
-            return observation._on(self, *exprs)
+            res = observation._on(self, *exprs)
         elif isinstance(observation, str):
-            return DataFrame(
+            res = DataFrame(
                 plan.CollectMetrics(self._plan, observation, list(exprs)),
                 self._session,
             )
@@ -1133,6 +1161,9 @@ class DataFrame(ParentDataFrame):
                 },
             )
 
+        res._cached_schema = self._cached_schema
+        return res
+
     def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: 
bool = False) -> None:
         print(self._show_string(n, truncate, vertical))
 
diff --git 
a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py 
b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index 4a7e1e1ea760..c712e5d6efcb 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -20,6 +20,9 @@ import unittest
 from pyspark.sql.types import StructType, StructField, StringType, 
IntegerType, LongType, DoubleType
 from pyspark.sql.utils import is_remote
 
+from pyspark.sql import functions as SF
+from pyspark.sql.connect import functions as CF
+
 from pyspark.sql.tests.connect.test_connect_basic import 
SparkConnectSQLTestCase
 from pyspark.testing.sqlutils import (
     have_pandas,
@@ -393,6 +396,38 @@ class 
SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
         # cannot infer when schemas mismatch
         self.assertTrue(cdf1.intersectAll(cdf3)._cached_schema is None)
 
+    def test_cached_schema_in_chain_op(self):
+        data = [(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)]
+
+        cdf = self.connect.createDataFrame(data, ("id", "v1"))
+        sdf = self.spark.createDataFrame(data, ("id", "v1"))
+
+        cdf1 = cdf.withColumn("v2", CF.lit(1))
+        sdf1 = sdf.withColumn("v2", SF.lit(1))
+
+        self.assertTrue(cdf1._cached_schema is None)
+        # trigger analysis of cdf1.schema
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertTrue(cdf1._cached_schema is not None)
+
+        cdf2 = cdf1.where(cdf1.v2 > 0)
+        sdf2 = sdf1.where(sdf1.v2 > 0)
+        self.assertEqual(cdf1._cached_schema, cdf2._cached_schema)
+
+        cdf3 = cdf2.repartition(10)
+        sdf3 = sdf2.repartition(10)
+        self.assertEqual(cdf1._cached_schema, cdf3._cached_schema)
+
+        cdf4 = cdf3.distinct()
+        sdf4 = sdf3.distinct()
+        self.assertEqual(cdf1._cached_schema, cdf4._cached_schema)
+
+        cdf5 = cdf4.sample(fraction=0.5)
+        sdf5 = sdf4.sample(fraction=0.5)
+        self.assertEqual(cdf1._cached_schema, cdf5._cached_schema)
+
+        self.assertEqual(cdf5.schema, sdf5.schema)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_dataframe_property import *  # 
noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to