This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 1a89bdc60d5 [SPARK-44620][SQL][PS][CONNECT] Make `ResolvePivot` retain the `Plan_ID_TAG`
1a89bdc60d5 is described below

commit 1a89bdc60d55394a1a9d94d4fa69fa5ab8041671
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Aug 3 11:57:34 2023 +0900

    [SPARK-44620][SQL][PS][CONNECT] Make `ResolvePivot` retain the `Plan_ID_TAG`

    ### What changes were proposed in this pull request?
    Make `ResolvePivot` retain the `Plan_ID_TAG`

    ### Why are the changes needed?
    to resolve the `AnalysisException` in Pandas APIs on Connect

    ### Does this PR introduce _any_ user-facing change?
    yes, new APIs enabled:
    1. `frame.pivot_table`
    2. `frame.transpose`
    3. `series.unstack`

    ### How was this patch tested?
    enabled UTs

    Closes #42261 from zhengruifeng/ps_connect_analyze_pivot.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../tests/connect/computation/test_parity_pivot.py | 17 +----------------
 .../pandas/tests/connect/frame/test_parity_reshaping.py | 11 +----------
 .../pandas/tests/connect/series/test_parity_compute.py | 6 +-----
 .../pandas/tests/connect/test_parity_categorical.py | 6 ------
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 +++++++---
 5 files changed, 10 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py index d2c4f9ae607..c8ec48eb06a 100644 --- a/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py +++ b/python/pyspark/pandas/tests/connect/computation/test_parity_pivot.py @@ -16,28 +16,13 @@ # import unittest -from pyspark import pandas as ps from pyspark.pandas.tests.computation.test_pivot import FramePivotMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils class FrameParityPivotTests(FramePivotMixin, 
PandasOnSparkTestUtils, ReusedConnectTestCase): - @property - def psdf(self): - return ps.from_pandas(self.pdf) - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_pivot_table(self): - super().test_pivot_table() - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_pivot_table_dtypes(self): - super().test_pivot_table_dtypes() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py index 98ebf3ca44a..e4bac7b078e 100644 --- a/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py +++ b/python/pyspark/pandas/tests/connect/frame/test_parity_reshaping.py @@ -16,22 +16,13 @@ # import unittest -from pyspark import pandas as ps from pyspark.pandas.tests.frame.test_reshaping import FrameReshapingMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.pandasutils import PandasOnSparkTestUtils class FrameParityReshapingTests(FrameReshapingMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @property - def psdf(self): - return ps.from_pandas(self.pdf) - - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." 
- ) - def test_transpose(self): - super().test_transpose() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py index f757d19ca69..8876fcb1398 100644 --- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +++ b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py @@ -22,11 +22,7 @@ from pyspark.testing.pandasutils import PandasOnSparkTestUtils class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." - ) - def test_unstack(self): - super().test_unstack() + pass if __name__ == "__main__": diff --git a/python/pyspark/pandas/tests/connect/test_parity_categorical.py b/python/pyspark/pandas/tests/connect/test_parity_categorical.py index 3e05eb2c0f3..210cfce8ddb 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_categorical.py +++ b/python/pyspark/pandas/tests/connect/test_parity_categorical.py @@ -53,12 +53,6 @@ class CategoricalParityTests( def test_set_categories(self): super().test_set_categories() - @unittest.skip( - "TODO(SPARK-43611): Fix unexpected `AnalysisException` from Spark Connect client." 
- ) - def test_unstack(self): - super().test_unstack() - if __name__ == "__main__": from pyspark.pandas.tests.connect.test_parity_categorical import * # noqa: F401 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1de745baa05..6c1d774a1b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -759,7 +759,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved) || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p - case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => if (!RowOrdering.isOrderable(pivotColumn.dataType)) { throw QueryCompilationErrors.unorderablePivotColError(pivotColumn) } @@ -823,7 +823,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() } } - Project(groupByExprsAttr ++ pivotOutputs, secondAgg) + val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg) + newProject.copyTagsFrom(p) + newProject } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(e: Expression) = { @@ -857,7 +859,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Alias(filteredAggregate, outputName(value, aggregate))() } } - Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + val newAggregate = Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) + newAggregate.copyTagsFrom(p) + newAggregate } } 
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org