Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/19872#discussion_r161856927
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---
@@ -334,34 +339,51 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case PhysicalAggregation(
-          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
-
-        val (functionsWithDistinct, functionsWithoutDistinct) =
-          aggregateExpressions.partition(_.isDistinct)
-        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
-          // This is a sanity check. We should not reach here when we have multiple distinct
-          // column sets. Our MultipleDistinctRewriter should take care this case.
-          sys.error("You hit a query analyzer bug. Please report your query to " +
-            "Spark user mailing list.")
-        }
+          groupingExpressions, aggExpressions, resultExpressions, child) =>
+
+        if (aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) {
-        val aggregateOperator =
-          if (functionsWithDistinct.isEmpty) {
-            aggregate.AggUtils.planAggregateWithoutDistinct(
-              groupingExpressions,
-              aggregateExpressions,
-              resultExpressions,
-              planLater(child))
-          } else {
-            aggregate.AggUtils.planAggregateWithOneDistinct(
-              groupingExpressions,
-              functionsWithDistinct,
-              functionsWithoutDistinct,
-              resultExpressions,
-              planLater(child))
+          val aggregateExpressions = aggExpressions.map(expr =>
+            expr.asInstanceOf[AggregateExpression])
+
+          val (functionsWithDistinct, functionsWithoutDistinct) =
+            aggregateExpressions.partition(_.isDistinct)
+          if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+            // This is a sanity check. We should not reach here when we have multiple distinct
+            // column sets. Our MultipleDistinctRewriter should take care this case.
+            sys.error("You hit a query analyzer bug. Please report your query to " +
+              "Spark user mailing list.")
           }
-        aggregateOperator
+          val aggregateOperator =
+            if (functionsWithDistinct.isEmpty) {
+              aggregate.AggUtils.planAggregateWithoutDistinct(
+                groupingExpressions,
+                aggregateExpressions,
+                resultExpressions,
+                planLater(child))
+            } else {
+              aggregate.AggUtils.planAggregateWithOneDistinct(
+                groupingExpressions,
+                functionsWithDistinct,
+                functionsWithoutDistinct,
+                resultExpressions,
+                planLater(child))
+            }
+
+          aggregateOperator
+        } else if (aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF])) {
+          val udfExpressions = aggExpressions.map(expr =>
+            expr.asInstanceOf[PythonUDF])
+
+          Seq(execution.python.AggregateInPandasExec(
+            groupingExpressions,
+            udfExpressions,
+            resultExpressions,
+            planLater(child)))
+        } else {
+          throw new IllegalArgumentException(
--- End diff ---
+1. Let's double-check in https://github.com/apache/spark/pull/19872#discussion_r161507315
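
For anyone skimming the thread, here is a minimal, self-contained sketch of the three-way dispatch this hunk introduces. AggDispatchSketch, its plan helper, and the two case classes below are simplified stand-ins invented for illustration, not Spark's real expression types: all-JVM aggregates take the existing AggUtils paths, all-PythonUDF aggregates are planned as AggregateInPandasExec, and a mix of the two is rejected.

    // Hypothetical, simplified model of the dispatch in the hunk above.
    // These case classes are illustrative stand-ins, not Spark's classes.
    sealed trait AggExpr
    final case class AggregateExpression(name: String, isDistinct: Boolean) extends AggExpr
    final case class PythonUDF(name: String) extends AggExpr

    object AggDispatchSketch {
      // Mirrors the forall-based type checks: all-JVM, all-Python, or error.
      def plan(aggExpressions: Seq[AggExpr]): String =
        if (aggExpressions.forall(_.isInstanceOf[AggregateExpression])) {
          // All JVM aggregates: choose between the distinct/non-distinct planners.
          val distinctCount = aggExpressions
            .collect { case a: AggregateExpression => a }
            .count(_.isDistinct)
          if (distinctCount == 0) "planAggregateWithoutDistinct"
          else "planAggregateWithOneDistinct"
        } else if (aggExpressions.forall(_.isInstanceOf[PythonUDF])) {
          // All pandas UDF aggregates: plan the Arrow-based physical operator.
          "AggregateInPandasExec"
        } else {
          // Mixing the two kinds in one aggregate is rejected outright.
          throw new IllegalArgumentException(
            "Cannot mix JVM aggregate functions and pandas UDF aggregates")
        }

      def main(args: Array[String]): Unit = {
        println(plan(Seq(AggregateExpression("sum", isDistinct = false)))) // JVM path
        println(plan(Seq(PythonUDF("mean_udf"))))                          // pandas path
        // plan(Seq(AggregateExpression("sum", false), PythonUDF("f")))    // would throw
      }
    }

Because forall is used for both guards, a query mixing JVM aggregate functions with pandas UDF aggregates falls through to the final branch and fails with an explicit error rather than being planned incorrectly.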