This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2d028a2ec19 [SPARK-41742] Support df.groupBy().agg({"*":"count"})
2d028a2ec19 is described below
commit 2d028a2ec19f1a9e41e3b2e893c412bd28ab53a4
Author: Martin Grund <[email protected]>
AuthorDate: Fri Dec 30 10:22:00 2022 +0800
[SPARK-41742] Support df.groupBy().agg({"*":"count"})
### What changes were proposed in this pull request?
Compatibility changes to support `count(*)` for DF operations that are
rewritten into `count(1)`.
### Why are the changes needed?
Compatibility.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
Closes #39298 from grundprinzip/SPARK-41742.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/group.py | 8 +++++++-
python/pyspark/sql/group.py | 4 +---
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 4c074d6da1b..fd6f9816e2d 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -80,8 +80,14 @@ class GroupedData:
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
+ # There is a special case for count(*) which is rewritten into count(1).
# Convert the dict into key value pairs
- aggregate_cols = [scalar_function(exprs[0][k], col(k)) for k in exprs[0]]
+ aggregate_cols = [
+ scalar_function(
+ exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k)
+ )
+ for k in exprs[0]
+ ]
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index ac661e39741..10468988186 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -78,8 +78,6 @@ class GroupedData(PandasGroupedOpsMixin):
def agg(self, __exprs: Dict[str, str]) -> DataFrame:
...
- # TODO(SPARK-41279): Enable the doctest with supporting the star in Spark Connect.
- # TODO(SPARK-41743): groupBy(...).agg(...).sort does not actually sort the output
def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame:
"""Compute aggregates and returns the result as a :class:`DataFrame`.
@@ -135,7 +133,7 @@ class GroupedData(PandasGroupedOpsMixin):
Group-by name, and count each group.
- >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()  # doctest: +SKIP
+ >>> df.groupBy(df.name).agg({"*": "count"}).sort("name").show()
+-----+--------+
| name|count(1)|
+-----+--------+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]