This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3d7da16440b [SPARK-41785][CONNECT][PYTHON] Implement `GroupedData.mean`
3d7da16440b is described below
commit 3d7da16440b824b9edd423ca31b795a3f9044f3c
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Dec 31 09:53:32 2022 +0900
[SPARK-41785][CONNECT][PYTHON] Implement `GroupedData.mean`
### What changes were proposed in this pull request?
Implement `GroupedData.mean`, the last missing grouped-data API in the Spark Connect Python client.
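The implementation simply aliases the existing `avg` at class level (see the diff below). For readers unfamiliar with the pattern, here is a minimal standalone sketch (not Spark's actual class) of how such an alias behaves:
```python
class Stats:
    def avg(self, *cols: str) -> str:
        # Stand-in for the real aggregation logic.
        return "avg(" + ", ".join(cols) + ")"

    # Class-level alias: `mean` is the very same function object as `avg`.
    mean = avg


assert Stats.mean is Stats.avg
assert Stats().mean("salary") == Stats().avg("salary")
```
Because both names are bound to the same function object, `mean` and `avg` cannot drift apart behaviorally.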
### Why are the changes needed?
For API parity with classic PySpark, where `GroupedData.mean` is already available.
### Does this PR introduce _any_ user-facing change?
Yes, `GroupedData.mean` is now supported on Spark Connect `GroupedData`.
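For example (an illustrative snippet assuming an active Spark Connect session bound to `spark`; the data and column names are hypothetical):
```python
# Hypothetical example data; any DataFrame with a numeric column works.
df = spark.createDataFrame(
    [("Alice", 10), ("Alice", 20), ("Bob", 30)], ["name", "salary"]
)

# `mean` is now available on grouped data and behaves like `avg`.
df.groupBy("name").mean("salary").show()
df.groupBy("name").avg("salary").show()
```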
### How was this patch tested?
Added unit tests.
Closes #39304 from zhengruifeng/connect_grouped_mean.
Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/group.py                    |  2 ++
 python/pyspark/sql/tests/connect/test_connect_basic.py | 12 ++++++++++++
 2 files changed, 14 insertions(+)
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index fd6f9816e2d..a6006c64158 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -166,6 +166,8 @@ class GroupedData:
     avg.__doc__ = PySparkGroupedData.avg.__doc__
+    mean = avg
+
     def count(self) -> "DataFrame":
         return self.agg(scalar_function("count", lit(1)))
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9663f3123f9..1ea93f7b743 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1483,6 +1483,10 @@ class SparkConnectTests(SparkConnectSQLTestCase):
cdf.groupBy("name", cdf.department).avg("salary",
"year").toPandas(),
sdf.groupBy("name", sdf.department).avg("salary",
"year").toPandas(),
)
+ self.assert_eq(
+ cdf.groupBy("name", cdf.department).mean("salary",
"year").toPandas(),
+ sdf.groupBy("name", sdf.department).mean("salary",
"year").toPandas(),
+ )
self.assert_eq(
cdf.groupBy("name", cdf.department).sum("salary",
"year").toPandas(),
sdf.groupBy("name", sdf.department).sum("salary",
"year").toPandas(),
@@ -1505,6 +1509,10 @@ class SparkConnectTests(SparkConnectSQLTestCase):
cdf.rollup("name", cdf.department).avg("salary",
"year").toPandas(),
sdf.rollup("name", sdf.department).avg("salary",
"year").toPandas(),
)
+ self.assert_eq(
+ cdf.rollup("name", cdf.department).mean("salary",
"year").toPandas(),
+ sdf.rollup("name", sdf.department).mean("salary",
"year").toPandas(),
+ )
self.assert_eq(
cdf.rollup("name", cdf.department).sum("salary",
"year").toPandas(),
sdf.rollup("name", sdf.department).sum("salary",
"year").toPandas(),
@@ -1515,6 +1523,10 @@ class SparkConnectTests(SparkConnectSQLTestCase):
cdf.cube("name").avg().toPandas(),
sdf.cube("name").avg().toPandas(),
)
+ self.assert_eq(
+ cdf.cube("name").mean().toPandas(),
+ sdf.cube("name").mean().toPandas(),
+ )
self.assert_eq(
cdf.cube("name").min("salary").toPandas(),
sdf.cube("name").min("salary").toPandas(),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]