This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 79f61985c8de [SPARK-46385][PYTHON][TESTS] Test aggregate functions for groups (pyspark.sql.group)
79f61985c8de is described below
commit 79f61985c8de6c1be5f82c79c50f2b7aa7b46f67
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed Dec 13 12:23:31 2023 -0800
[SPARK-46385][PYTHON][TESTS] Test aggregate functions for groups (pyspark.sql.group)
### What changes were proposed in this pull request?
Test aggregate functions for groups (pyspark.sql.group)
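For reference, a minimal standalone sketch of the pyspark.sql.group (GroupedData) aggregate API that the new test exercises; the local SparkSession setup and the sample data below are illustrative assumptions and are not part of this patch:

    from pyspark.sql import Row, SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([Row(key=1, value=10), Row(key=1, value=20)])
    g = df.groupBy("key")

    # Built-in aggregates on GroupedData return one row per group,
    # with generated column names such as max(value) and avg(value).
    g.max("value").show()
    g.count().show()

    # pivot() turns distinct values of a column into output columns;
    # passing the values explicitly avoids an extra pass to infer them.
    df.groupBy("key").pivot("value", [10, 20]).sum("value").show()

    spark.stop()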
### Why are the changes needed?
A subtask of [SPARK-46041](https://issues.apache.org/jira/browse/SPARK-46041) to improve test coverage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Test change only.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44322 from xinrong-meng/test_group.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_group.py | 23 +++++++++++++++++++++++
1 file changed, 23 insertions(+)
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index 6981601cb129..6c84bd740171 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -22,6 +22,29 @@ from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
 class GroupTestsMixin:
+    def test_agg_func(self):
+        data = [Row(key=1, value=10), Row(key=1, value=20), Row(key=1, value=30)]
+        df = self.spark.createDataFrame(data)
+        g = df.groupBy("key")
+        self.assertEqual(g.max("value").collect(), [Row(**{"key": 1, "max(value)": 30})])
+        self.assertEqual(g.min("value").collect(), [Row(**{"key": 1, "min(value)": 10})])
+        self.assertEqual(g.sum("value").collect(), [Row(**{"key": 1, "sum(value)": 60})])
+        self.assertEqual(g.count().collect(), [Row(key=1, count=3)])
+        self.assertEqual(g.mean("value").collect(), [Row(**{"key": 1, "avg(value)": 20.0})])
+
+        data = [
+            Row(electronic="Smartphone", year=2018, sales=150000),
+            Row(electronic="Tablet", year=2018, sales=120000),
+            Row(electronic="Smartphone", year=2019, sales=180000),
+            Row(electronic="Tablet", year=2019, sales=50000),
+        ]
+
+        df_pivot = self.spark.createDataFrame(data)
+        assertDataFrameEqual(
+            df_pivot.groupBy("year").pivot("electronic", ["Smartphone", "Tablet"]).sum("sales"),
+            df_pivot.groupBy("year").pivot("electronic").sum("sales"),
+        )
+
     def test_aggregator(self):
         df = self.df
         g = df.groupBy()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]