This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new eb4bb44d446 [SPARK-42099][SPARK-41845][CONNECT][PYTHON] Fix `count(*)` and `count(col(*))`
eb4bb44d446 is described below
commit eb4bb44d446a0416c360da8127659b10f98e5ceb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Jan 21 09:48:25 2023 +0800
[SPARK-42099][SPARK-41845][CONNECT][PYTHON] Fix `count(*)` and `count(col(*))`
### What changes were proposed in this pull request?
1. Add `UnresolvedStar` to `expressions.py`;
2. Fix `count(*)` and `count(col(*))`: they should return `Column(UnresolvedStar(None))` instead of `Column(UnresolvedAttribute("*"))`; see:
https://github.com/apache/spark/blob/68531ada34db72d352c39396f85458a8370af812/sql/core/src/main/scala/org/apache/spark/sql/Column.scala#L144-L150
3. Remove the `count(*) -> count(1)` transformation in `group.py`, since it is no longer needed.
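
As an illustrative sketch (not part of the diff; the reprs are approximate), the client-side mapping after this change is:

```python
from pyspark.sql.connect import functions as CF

# "*" alone becomes an unqualified star expansion.
CF.col("*")    # Column(UnresolvedStar(unparsed_target=None))

# A name ending in ".*" keeps its unparsed target; the server strips the
# trailing ".*" and parses the rest with UnresolvedAttribute.parseAttributeName.
CF.col("a.*")  # Column(UnresolvedStar(unparsed_target="a.*"))

# Any other name stays a plain column reference.
CF.col("age")  # Column(ColumnReference(unparsed_identifier="age"))
```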
### Why are the changes needed?
https://github.com/apache/spark/pull/39636 fixed the `count(*)` issue on the server side, so `count(expr("*"))` works after that PR. This PR makes the corresponding changes on the Python client side to support `count(*)` and `count(col(*))`.
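
For example, a hedged usage sketch mirroring the added tests (`spark` is assumed to be a Spark Connect session; data and column names are illustrative):

```python
from pyspark.sql.connect import functions as CF

cdf = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"])

# All three spellings now reach the server as a star expansion:
cdf.select(CF.count("*")).collect()
cdf.select(CF.count(CF.col("*"))).collect()
cdf.select(CF.count(CF.expr("*"))).collect()

# The dict form of agg no longer needs the client-side count(1) rewrite:
cdf.groupby("name").agg({"*": "count"}).collect()
```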
### Does this PR introduce _any_ user-facing change?
Yes: `count(*)` and `count(col(*))` now work in the Python Connect client.
### How was this patch tested?
Enabled existing unit tests and added new ones.
Closes #39622 from zhengruifeng/connect_count_star.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 9 ++--
.../sql/connect/planner/SparkConnectPlanner.scala | 16 ++++--
.../connect/planner/SparkConnectPlannerSuite.scala | 2 +-
python/pyspark/sql/connect/dataframe.py | 4 +-
python/pyspark/sql/connect/expressions.py | 46 +++++++++++++++--
python/pyspark/sql/connect/functions.py | 11 ++--
python/pyspark/sql/connect/group.py | 8 +--
.../pyspark/sql/connect/proto/expressions_pb2.py | 30 +++++------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 31 ++++++++----
.../sql/tests/connect/test_connect_function.py | 58 ++++++++++++++++++++++
10 files changed, 166 insertions(+), 49 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 349e2455be3..f7feae0e2f0 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -225,9 +225,12 @@ message Expression {
// UnresolvedStar is used to expand all the fields of a relation or struct.
message UnresolvedStar {
- // (Optional) The target of the expansion, either be a table name or struct name, this
- // is a list of identifiers that is the path of the expansion.
- repeated string target = 1;
+
+ // (Optional) The target of the expansion.
+ //
+ // If set, it should end with '.*' and will be parsed by 'parseAttributeName'
+ // in the server side.
+ optional string unparsed_target = 1;
}
// Represents all of the input attributes to a given relational operator, for example in
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3d63558eb3e..d72aa162132 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1002,11 +1002,19 @@ class SparkConnectPlanner(session: SparkSession) {
session.sessionState.sqlParser.parseExpression(expr.getExpression)
}
- private def transformUnresolvedStar(regex: proto.Expression.UnresolvedStar): Expression = {
- if (regex.getTargetList.isEmpty) {
- UnresolvedStar(Option.empty)
+ private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): UnresolvedStar = {
+ if (star.hasUnparsedTarget) {
+ val target = star.getUnparsedTarget
+ if (!target.endsWith(".*")) {
+ throw InvalidPlanInput(
+ s"UnresolvedStar requires a unparsed target ending with '.*', " +
+ s"but got $target.")
+ }
+
+ UnresolvedStar(
+ Some(UnresolvedAttribute.parseAttributeName(target.substring(0, target.length - 2))))
} else {
- UnresolvedStar(Some(regex.getTargetList.asScala.toSeq))
+ UnresolvedStar(None)
}
}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 63e5415b44f..d8baa182e5a 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -552,7 +552,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
.addExpressions(
proto.Expression
.newBuilder()
- .setUnresolvedStar(UnresolvedStar.newBuilder().addTarget("a").addTarget("b").build())
+ .setUnresolvedStar(UnresolvedStar.newBuilder().setUnparsedTarget("a.b.*").build())
.build())
.build()
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 11c0ef6fc06..d82862a870b 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -249,7 +249,7 @@ class DataFrame:
else:
return Column(SortOrder(col._expr))
else:
- return Column(SortOrder(ColumnReference(name=col)))
+ return Column(SortOrder(ColumnReference(col)))
if isinstance(numPartitions, int):
if not numPartitions > 0:
@@ -1176,7 +1176,7 @@ class DataFrame:
from pyspark.sql.connect.expressions import ColumnReference
if isinstance(col, str):
- col = Column(ColumnReference(name=col))
+ col = Column(ColumnReference(col))
elif not isinstance(col, Column):
raise TypeError("col must be a string or a column, but got %r" % type(col))
if not isinstance(fractions, dict):
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 6469c1917ec..c8d361af2a5 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -336,10 +336,10 @@ class ColumnReference(Expression):
treat it as an unresolved attribute. Attributes that have the same fully
qualified name are identical"""
- def __init__(self, name: str) -> None:
+ def __init__(self, unparsed_identifier: str) -> None:
super().__init__()
- assert isinstance(name, str)
- self._unparsed_identifier = name
+ assert isinstance(unparsed_identifier, str)
+ self._unparsed_identifier = unparsed_identifier
def name(self) -> str:
"""Returns the qualified name of the column reference."""
@@ -354,6 +354,43 @@ class ColumnReference(Expression):
def __repr__(self) -> str:
return f"{self._unparsed_identifier}"
+ def __eq__(self, other: Any) -> bool:
+ return (
+ other is not None
+ and isinstance(other, ColumnReference)
+ and other._unparsed_identifier == self._unparsed_identifier
+ )
+
+
+class UnresolvedStar(Expression):
+ def __init__(self, unparsed_target: Optional[str]):
+ super().__init__()
+
+ if unparsed_target is not None:
+ assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*")
+
+ self._unparsed_target = unparsed_target
+
+ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+ expr = proto.Expression()
+ expr.unresolved_star.SetInParent()
+ if self._unparsed_target is not None:
+ expr.unresolved_star.unparsed_target = self._unparsed_target
+ return expr
+
+ def __repr__(self) -> str:
+ if self._unparsed_target is not None:
+ return f"unresolvedstar({self._unparsed_target})"
+ else:
+ return "unresolvedstar()"
+
+ def __eq__(self, other: Any) -> bool:
+ return (
+ other is not None
+ and isinstance(other, UnresolvedStar)
+ and other._unparsed_target == self._unparsed_target
+ )
+
class SQLExpression(Expression):
"""Returns Expression which contains a string which is a SQL expression
@@ -370,6 +407,9 @@ class SQLExpression(Expression):
expr.expression_string.expression = self._expr
return expr
+ def __eq__(self, other: Any) -> bool:
+ return other is not None and isinstance(other, SQLExpression) and other._expr == self._expr
+
class SortOrder(Expression):
def __init__(self, child: Expression, ascending: bool = True, nullsFirst: bool = True) -> None:
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 5f1eb9c06d7..c73e6ec1ee4 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -40,6 +40,7 @@ from pyspark.sql.connect.expressions import (
LiteralExpression,
ColumnReference,
UnresolvedFunction,
+ UnresolvedStar,
SQLExpression,
LambdaFunction,
UnresolvedNamedLambdaVariable,
@@ -186,7 +187,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column:
def col(col: str) -> Column:
- return Column(ColumnReference(col))
+ if col == "*":
+ return Column(UnresolvedStar(unparsed_target=None))
+ elif col.endswith(".*"):
+ return Column(UnresolvedStar(unparsed_target=col))
+ else:
+ return Column(ColumnReference(unparsed_identifier=col))
col.__doc__ = pysparkfuncs.col.__doc__
@@ -2389,9 +2395,6 @@ def _test() -> None:
# TODO(SPARK-41843): Implement SparkSession.udf
del pyspark.sql.connect.functions.call_udf.__doc__
- # TODO(SPARK-41845): Fix count bug
- del pyspark.sql.connect.functions.count.__doc__
-
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.functions tests")
.remote("local[4]")
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 3aa070ff8b6..cc728808d3a 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -80,14 +80,8 @@ class GroupedData:
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
- # There is a special case for count(*) which is rewritten into count(1).
# Convert the dict into key value pairs
- aggregate_cols = [
- _invoke_function(
- exprs[0][k], lit(1) if exprs[0][k] == "count" and k == "*" else col(k)
- )
- for k in exprs[0]
- ]
+ aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]]
else:
# Columns
assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 462384999bb..87c16964102 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe8#\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
[...]
+
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92$\n\nExpression\x12=\n\x07literal\x18\x01
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
[...]
)
@@ -262,7 +262,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 4689
+ _EXPRESSION._serialized_end = 4731
_EXPRESSION_WINDOW._serialized_start = 1347
_EXPRESSION_WINDOW._serialized_end = 2130
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1637
@@ -292,17 +292,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_EXPRESSIONSTRING._serialized_start = 3866
_EXPRESSION_EXPRESSIONSTRING._serialized_end = 3916
_EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3918
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 3958
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 3960
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4004
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4007
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4139
- _EXPRESSION_UPDATEFIELDS._serialized_start = 4142
- _EXPRESSION_UPDATEFIELDS._serialized_end = 4329
- _EXPRESSION_ALIAS._serialized_start = 4331
- _EXPRESSION_ALIAS._serialized_end = 4451
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4454
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4612
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4614
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4676
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4000
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4002
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4046
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4049
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4181
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 4184
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4371
+ _EXPRESSION_ALIAS._serialized_start = 4373
+ _EXPRESSION_ALIAS._serialized_end = 4493
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4496
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4654
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4656
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4718
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 5f64159b854..45889c1518f 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -699,22 +699,33 @@ class Expression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- TARGET_FIELD_NUMBER: builtins.int
- @property
- def target(
- self,
- ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
- """(Optional) The target of the expansion, either be a table name
or struct name, this
- is a list of identifiers that is the path of the expansion.
- """
+ UNPARSED_TARGET_FIELD_NUMBER: builtins.int
+ unparsed_target: builtins.str
+ """(Optional) The target of the expansion.
+
+ If set, it should end with '.*' and will be parsed by 'parseAttributeName'
+ in the server side.
+ """
def __init__(
self,
*,
- target: collections.abc.Iterable[builtins.str] | None = ...,
+ unparsed_target: builtins.str | None = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_unparsed_target", b"_unparsed_target", "unparsed_target",
b"unparsed_target"
+ ],
+ ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["target", b"target"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_unparsed_target", b"_unparsed_target", "unparsed_target",
b"unparsed_target"
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_unparsed_target",
b"_unparsed_target"]
+ ) -> typing_extensions.Literal["unparsed_target"] | None: ...
class UnresolvedRegex(google.protobuf.message.Message):
"""Represents all of the input attributes to a given relational
operator, for example in
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 199fd6eb9a9..e1792b03a44 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -71,6 +71,64 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
self.assertEqual(str1, str2)
+ def test_count_star(self):
+ # SPARK-42099: test count(*), count(col(*)) and count(expr(*))
+
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ data = [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")]
+
+ cdf = self.connect.createDataFrame(data, schema=["age", "name"])
+ sdf = self.spark.createDataFrame(data, schema=["age", "name"])
+
+ self.assertEqual(
+ cdf.select(CF.count(CF.expr("*")), CF.count(cdf.age)).collect(),
+ sdf.select(SF.count(SF.expr("*")), SF.count(sdf.age)).collect(),
+ )
+
+ self.assertEqual(
+ cdf.select(CF.count(CF.col("*")), CF.count(cdf.age)).collect(),
+ sdf.select(SF.count(SF.col("*")), SF.count(sdf.age)).collect(),
+ )
+
+ self.assertEqual(
+ cdf.select(CF.count("*"), CF.count(cdf.age)).collect(),
+ sdf.select(SF.count("*"), SF.count(sdf.age)).collect(),
+ )
+
+ self.assertEqual(
+ cdf.groupby("name").agg({"*": "count"}).sort("name").collect(),
+ sdf.groupby("name").agg({"*": "count"}).sort("name").collect(),
+ )
+
+ self.assertEqual(
+ cdf.groupby("name")
+ .agg(CF.count(CF.expr("*")), CF.count(cdf.age))
+ .sort("name")
+ .collect(),
+ sdf.groupby("name")
+ .agg(SF.count(SF.expr("*")), SF.count(sdf.age))
+ .sort("name")
+ .collect(),
+ )
+
+ self.assertEqual(
+ cdf.groupby("name")
+ .agg(CF.count(CF.col("*")), CF.count(cdf.age))
+ .sort("name")
+ .collect(),
+ sdf.groupby("name")
+ .agg(SF.count(SF.col("*")), SF.count(sdf.age))
+ .sort("name")
+ .collect(),
+ )
+
+ self.assertEqual(
+ cdf.groupby("name").agg(CF.count("*"),
CF.count(cdf.age)).sort("name").collect(),
+ sdf.groupby("name").agg(SF.count("*"),
SF.count(sdf.age)).sort("name").collect(),
+ )
+
def test_broadcast(self):
from pyspark.sql import functions as SF
from pyspark.sql.connect import functions as CF
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]