This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a89bdc60d5 [SPARK-44620][SQL][PS][CONNECT] Make `ResolvePivot` retain 
the `Plan_ID_TAG`
1a89bdc60d5 is described below

commit 1a89bdc60d55394a1a9d94d4fa69fa5ab8041671
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Aug 3 11:57:34 2023 +0900

    [SPARK-44620][SQL][PS][CONNECT] Make `ResolvePivot` retain the `Plan_ID_TAG`
    
    ### What changes were proposed in this pull request?
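    Make `ResolvePivot` retain the `Plan_ID_TAG` by copying the tags of the
    original `Pivot` node onto the `Project` / `Aggregate` plan that replaces it.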
    
    ### Why are the changes needed?
    To resolve the unexpected `AnalysisException` raised by pandas API on Spark
    operations when they run through the Spark Connect client.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the following pandas API on Spark methods now work on Spark Connect:
    
    1. `frame.pivot_table`
    2. `frame.transpose`
    3. `series.unstack`
    
    ### How was this patch tested?
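    Enabled the previously skipped unit tests (removed the `@unittest.skip`
    overrides in the Connect parity test suites).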
    
    Closes #42261 from zhengruifeng/ps_connect_analyze_pivot.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../tests/connect/computation/test_parity_pivot.py      | 17 +----------------
 .../pandas/tests/connect/frame/test_parity_reshaping.py | 11 +----------
 .../pandas/tests/connect/series/test_parity_compute.py  |  6 +-----
 .../pandas/tests/connect/test_parity_categorical.py     |  6 ------
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala   | 10 +++++++---
 5 files changed, 10 insertions(+), 40 deletions(-)

diff --git 
a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py 
b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py
index d2c4f9ae607..c8ec48eb06a 100644
--- a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py
+++ b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py
@@ -16,28 +16,13 @@
 #
 import unittest
 
-from pyspark import pandas as ps
 from pyspark.pandas.tests.computation.test_pivot import FramePivotMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
 class FrameParityPivotTests(FramePivotMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase):
-    @property
-    def psdf(self):
-        return ps.from_pandas(self.pdf)
-
-    @unittest.skip(
-        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark 
Connect client."
-    )
-    def test_pivot_table(self):
-        super().test_pivot_table()
-
-    @unittest.skip(
-        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark 
Connect client."
-    )
-    def test_pivot_table_dtypes(self):
-        super().test_pivot_table_dtypes()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py 
b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py
index 98ebf3ca44a..e4bac7b078e 100644
--- a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py
+++ b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py
@@ -16,22 +16,13 @@
 #
 import unittest
 
-from pyspark import pandas as ps
 from pyspark.pandas.tests.frame.test_reshaping import FrameReshapingMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 
 class FrameParityReshapingTests(FrameReshapingMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase):
-    @property
-    def psdf(self):
-        return ps.from_pandas(self.pdf)
-
-    @unittest.skip(
-        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark 
Connect client."
-    )
-    def test_transpose(self):
-        super().test_transpose()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py 
b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
index f757d19ca69..8876fcb1398 100644
--- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
+++ b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py
@@ -22,11 +22,7 @@ from pyspark.testing.pandasutils import 
PandasOnSparkTestUtils
 
 
 class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, 
ReusedConnectTestCase):
-    @unittest.skip(
-        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark 
Connect client."
-    )
-    def test_unstack(self):
-        super().test_unstack()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/connect/test_parity_categorical.py 
b/python/pyspark/pandas/tests/connect/test_parity_categorical.py
index 3e05eb2c0f3..210cfce8ddb 100644
--- a/python/pyspark/pandas/tests/connect/test_parity_categorical.py
+++ b/python/pyspark/pandas/tests/connect/test_parity_categorical.py
@@ -53,12 +53,6 @@ class CategoricalParityTests(
     def test_set_categories(self):
         super().test_set_categories()
 
-    @unittest.skip(
-        "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark 
Connect client."
-    )
-    def test_unstack(self):
-        super().test_unstack()
-
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.connect.test_parity_categorical import *  # 
noqa: F401
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1de745baa05..6c1d774a1b5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -759,7 +759,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
       case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
         || (p.groupByExprsOpt.isDefined && 
!p.groupByExprsOpt.get.forall(_.resolved))
         || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
-      case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) 
=>
+      case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, 
child) =>
         if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
           throw QueryCompilationErrors.unorderablePivotColError(pivotColumn)
         }
@@ -823,7 +823,9 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
               Alias(ExtractValue(pivotAtt, Literal(i), resolver), 
outputName(value, aggregate))()
             }
           }
-          Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
+          val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
+          newProject.copyTagsFrom(p)
+          newProject
         } else {
           val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { 
value =>
             def ifExpr(e: Expression) = {
@@ -857,7 +859,9 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
               Alias(filteredAggregate, outputName(value, aggregate))()
             }
           }
-          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
+          val newAggregate = Aggregate(groupByExprs, groupByExprs ++ 
pivotAggregates, child)
+          newAggregate.copyTagsFrom(p)
+          newAggregate
         }
     }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to