This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d7da16440b [SPARK-41785][CONNECT][PYTHON] Implement `GroupedData.mean`
3d7da16440b is described below

commit 3d7da16440b824b9edd423ca31b795a3f9044f3c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Dec 31 09:53:32 2022 +0900

    [SPARK-41785][CONNECT][PYTHON] Implement `GroupedData.mean`
    
    ### What changes were proposed in this pull request?
    Implement `GroupedData.mean` - the last missing API for grouped data
    
    ### Why are the changes needed?
    For API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    Added a unit test.
    
    Closes #39304 from zhengruifeng/connect_grouped_mean.
    
    Lead-authored-by: Ruifeng Zheng <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/group.py                    |  2 ++
 python/pyspark/sql/tests/connect/test_connect_basic.py | 12 ++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index fd6f9816e2d..a6006c64158 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -166,6 +166,8 @@ class GroupedData:
 
     avg.__doc__ = PySparkGroupedData.avg.__doc__
 
+    mean = avg
+
     def count(self) -> "DataFrame":
         return self.agg(scalar_function("count", lit(1)))
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9663f3123f9..1ea93f7b743 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1483,6 +1483,10 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             cdf.groupBy("name", cdf.department).avg("salary", "year").toPandas(),
             sdf.groupBy("name", sdf.department).avg("salary", "year").toPandas(),
         )
+        self.assert_eq(
+            cdf.groupBy("name", cdf.department).mean("salary", "year").toPandas(),
+            sdf.groupBy("name", sdf.department).mean("salary", "year").toPandas(),
+        )
         self.assert_eq(
             cdf.groupBy("name", cdf.department).sum("salary", "year").toPandas(),
             sdf.groupBy("name", sdf.department).sum("salary", "year").toPandas(),
@@ -1505,6 +1509,10 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             cdf.rollup("name", cdf.department).avg("salary", "year").toPandas(),
             sdf.rollup("name", sdf.department).avg("salary", "year").toPandas(),
         )
+        self.assert_eq(
+            cdf.rollup("name", cdf.department).mean("salary", "year").toPandas(),
+            sdf.rollup("name", sdf.department).mean("salary", "year").toPandas(),
+        )
         self.assert_eq(
             cdf.rollup("name", cdf.department).sum("salary", "year").toPandas(),
             sdf.rollup("name", sdf.department).sum("salary", "year").toPandas(),
@@ -1515,6 +1523,10 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             cdf.cube("name").avg().toPandas(),
             sdf.cube("name").avg().toPandas(),
         )
+        self.assert_eq(
+            cdf.cube("name").mean().toPandas(),
+            sdf.cube("name").mean().toPandas(),
+        )
         self.assert_eq(
             cdf.cube("name").min("salary").toPandas(),
             sdf.cube("name").min("salary").toPandas(),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to