This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bfb0f016817d [SPARK-46677][CONNECT][FOLLOWUP] Convert `count(df["*"])`
to `count(1)` on client side
bfb0f016817d is described below
commit bfb0f016817d9abfb648bd47f7c5164e6e1004a7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 16 18:02:36 2024 +0800
[SPARK-46677][CONNECT][FOLLOWUP] Convert `count(df["*"])` to `count(1)` on
client side
### What changes were proposed in this pull request?
Before https://github.com/apache/spark/pull/44689, `df["*"]` and
`sf.col("*")` were both converted to `UnresolvedStar`, and
`Count(UnresolvedStar)` was then converted to `Count(1)` in the Analyzer:
https://github.com/apache/spark/blob/381f3691bd481abc8f621ca3f282e06db32bea31/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala#L1893-L1897
That fix introduced a new node, `UnresolvedDataFrameStar`, for
`df["*"]`, which is later replaced with `ResolvedStar`. Unfortunately, it
no longer matches the `Count(UnresolvedStar)` pattern.
This causes:
```
In [1]: from pyspark.sql import functions as sf
In [2]: df1 = spark.createDataFrame([{"id": 1, "val": "v"}])
In [3]: df1.select(sf.count(df1["*"]))
Out[3]: DataFrame[count(id, val): bigint]
```
whereas it should be:
```
In [3]: df1.select(sf.count(df1["*"]))
Out[3]: DataFrame[count(1): bigint]
```
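A minimal sketch of why the rewrite stops firing (the class and function names below are illustrative stand-ins, not Spark's actual Analyzer internals):

```python
# Illustrative sketch only: stand-in classes, not Spark's real nodes.
class UnresolvedStar:
    """Old representation of `*`, shared by col("*") and df["*"]."""

class UnresolvedDataFrameStar:
    """New node introduced for df["*"] by SPARK-46677."""

def analyze_count(arg):
    # The Analyzer rule only matches the old node:
    # Count(UnresolvedStar) -> Count(1).
    if isinstance(arg, UnresolvedStar):
        return "count(1)"
    # The new node falls through and is later expanded to all columns,
    # yielding e.g. count(id, val) instead of count(1).
    return "count(<expanded columns>)"
```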
In vanilla Spark, the `count` function itself performs this conversion,
rewriting `sf.count(df1["*"])` to `sf.count(sf.lit(1))`; see
https://github.com/apache/spark/blob/e8dfcd3081abe16b2115bb2944a2b1cb547eca8e/sql/core/src/main/scala/org/apache/spark/sql/functions.scala#L422-L436
So fixing this behavior on the client side is the natural approach.
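As a standalone illustration of the client-side rewrite (`Star`, `Lit`, and `count_fn` are hypothetical stand-ins, not the actual `pyspark.sql.connect` classes):

```python
# Hedged sketch of the client-side fix; Star, Lit, and count_fn are
# hypothetical stand-ins, not the real Connect expression classes.
class Star:
    """Stand-in for an unresolved-star expression."""

class Lit:
    """Stand-in for a literal expression."""
    def __init__(self, value):
        self.value = value

def count_fn(arg):
    # Rewrite count(*) to count(1) before building the Connect plan,
    # so the server never sees Count(UnresolvedStar).
    if isinstance(arg, Star):
        arg = Lit(1)
    return ("count", arg)
```

This mirrors both sides of the patch below: the Scala client matches the `UNRESOLVED_STAR` case of the proto expression, and the Python client checks `isinstance(col._expr, UnresolvedStar)`.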
### Why are the changes needed?
To keep the behavior consistent with vanilla Spark.
### Does this PR introduce _any_ user-facing change?
Yes, it fixes a behavior change introduced in
https://github.com/apache/spark/pull/44689.
### How was this patch tested?
Added a unit test.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44752 from zhengruifeng/connect_fix_count_df_star.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../src/main/scala/org/apache/spark/sql/functions.scala | 9 ++++++++-
.../test/resources/query-tests/queries/groupby_agg.json | 3 ++-
.../resources/query-tests/queries/groupby_agg.proto.bin | Bin 208 -> 210 bytes
python/pyspark/sql/connect/functions/builtin.py | 2 ++
python/pyspark/sql/tests/test_dataframe.py | 15 +++++++++++++++
5 files changed, 27 insertions(+), 2 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 9191633171f7..2a48958d4222 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -402,7 +402,14 @@ object functions {
* @group agg_funcs
* @since 3.4.0
*/
- def count(e: Column): Column = Column.fn("count", e)
+ def count(e: Column): Column = {
+ val withoutStar = e.expr.getExprTypeCase match {
+ // Turn count(*) into count(1)
+ case proto.Expression.ExprTypeCase.UNRESOLVED_STAR => lit(1)
+ case _ => e
+ }
+ Column.fn("count", withoutStar)
+ }
/**
* Aggregate function: returns the number of items in a group.
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
index 4a1cfddb0288..65f266794828 100644
---
a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
+++
b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.json
@@ -81,7 +81,8 @@
"unresolvedFunction": {
"functionName": "count",
"arguments": [{
- "unresolvedStar": {
+ "literal": {
+ "integer": 1
}
}]
}
diff --git
a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
index cfd6c2daa84b..18d8c6ce4115 100644
Binary files
a/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
and
b/connector/connect/common/src/test/resources/query-tests/queries/groupby_agg.proto.bin
differ
diff --git a/python/pyspark/sql/connect/functions/builtin.py
b/python/pyspark/sql/connect/functions/builtin.py
index 2eeefc9fae23..1e22a42c6241 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -1010,6 +1010,8 @@ corr.__doc__ = pysparkfuncs.corr.__doc__
def count(col: "ColumnOrName") -> Column:
+ if isinstance(col, Column) and isinstance(col._expr, UnresolvedStar):
+ col = lit(1)
return _invoke_function_over_columns("count", col)
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index 407ab22a088c..1788f1d9fb1a 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -104,6 +104,21 @@ class DataFrameTestsMixin:
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])
+ def test_count_star(self):
+ df1 = self.spark.createDataFrame([{"a": 1}])
+ df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}])
+ df3 = df2.select(struct("a", "b").alias("s"))
+
+ self.assertEqual(df1.select(count(df1["*"])).columns, ["count(1)"])
+ self.assertEqual(df1.select(count(col("*"))).columns, ["count(1)"])
+
+ self.assertEqual(df2.select(count(df2["*"])).columns, ["count(1)"])
+ self.assertEqual(df2.select(count(col("*"))).columns, ["count(1)"])
+
+ self.assertEqual(df3.select(count(df3["*"])).columns, ["count(1)"])
+ self.assertEqual(df3.select(count(col("*"))).columns, ["count(1)"])
+ self.assertEqual(df3.select(count(col("s.*"))).columns, ["count(1)"])
+
def test_self_join(self):
df1 = self.spark.range(10).withColumn("a", lit(0))
df2 = df1.withColumnRenamed("a", "b")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]