This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 32ff77cdb8e [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF 32ff77cdb8e is described below commit 32ff77cdb8ef4973494beb1a31ced05ea493dc6d Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Wed Nov 30 19:06:12 2022 +0800 [SPARK-41325][CONNECT] Fix missing avg() for GroupBy on DF ### What changes were proposed in this pull request? Previously, the `avg` function was missing in the `GroupedData` class. This patch adds this method and the necessary plan transformation using an unresolved function. In addition, it identified a small issue: when an alias is used for a grouping column, the planner would incorrectly try to wrap the existing alias expression in an unresolved alias, which would then fail. ``` df = ( self.connect.range(10) .groupBy((col("id") % lit(2)).alias("moded")) .avg("id") .sort("moded") ) ``` ### Why are the changes needed? Bug / Compatibility ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #38841 from grundprinzip/SPARK-41325. 
Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../spark/sql/connect/planner/SparkConnectPlanner.scala | 3 ++- python/pyspark/sql/connect/dataframe.py | 4 ++++ python/pyspark/sql/tests/connect/test_connect_basic.py | 13 +++++++++++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7b9e13cadab..d1d4c3d4fa9 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -682,7 +682,8 @@ class SparkConnectPlanner(session: SparkSession) { rel.getGroupingExpressionsList.asScala .map(transformExpression) .map { - case x @ UnresolvedAttribute(_) => x + case ua @ UnresolvedAttribute(_) => ua + case a @ Alias(_, _) => a case x => UnresolvedAlias(x) } diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index c9960a71fb8..ebfb52cdd74 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -84,6 +84,10 @@ class GroupedData(object): expr = self._map_cols_to_expression("sum", col) return self.agg(expr) + def avg(self, col: Union[Column, str]) -> "DataFrame": + expr = self._map_cols_to_expression("avg", col) + return self.agg(expr) + def count(self) -> "DataFrame": return self.agg([scalar_function("count", lit(1))]) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index f518a09ad4a..22d57994794 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -837,6 +837,19 @@ class SparkConnectTests(SparkConnectSQLTestCase): ndf = 
self.connect.read.table("parquet_test") self.assertEqual(set(df.collect()), set(ndf.collect())) + def test_agg_with_avg(self): + # SPARK-41325: groupby.avg() + df = ( + self.connect.range(10) + .groupBy((col("id") % lit(2)).alias("moded")) + .avg("id") + .sort("moded") + ) + res = df.collect() + self.assertEqual(2, len(res)) + self.assertEqual(4.0, res[0][1]) + self.assertEqual(5.0, res[1][1]) + class ChannelBuilderTests(ReusedPySparkTestCase): def test_invalid_connection_strings(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org