zhenlineo commented on code in PR #41501:
URL: https://github.com/apache/spark/pull/41501#discussion_r1228433793


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -1986,39 +1990,36 @@ class SparkConnectPlanner(val session: SparkSession) 
extends Logging {
   }
 
   private def transformAggregate(rel: proto.Aggregate): LogicalPlan = {
-    rel.getGroupType match {
-      case proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
-          // This relies on the assumption that a KVGDS always requires the 
head to be a Typed UDF.
-          // This is the case for datasets created via groupByKey,
-          // and also via RelationalGroupedDS#as, as the first is a dummy UDF 
currently.
-          if rel.getGroupingExpressionsList.size() >= 1 &&
-            isTypedScalaUdfExpr(rel.getGroupingExpressionsList.get(0)) =>
-        transformKeyValueGroupedAggregate(rel)
-      case _ =>
-        transformRelationalGroupedAggregate(rel)
-    }
-  }
-
-  private def transformKeyValueGroupedAggregate(rel: proto.Aggregate): 
LogicalPlan = {
-    val input = transformRelation(rel.getInput)
-    val ds = UntypedKeyValueGroupedDataset(input, 
rel.getGroupingExpressionsList, Seq.empty)
-
-    val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, 
ds.groupingAttributes)
-    val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq
-      .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
-      .map(toNamedExpression)
-    logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, 
ds.analyzed)
-  }
-
-  private def transformRelationalGroupedAggregate(rel: proto.Aggregate): 
LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Aggregate needs a plan input")
     }
-    val input = transformRelation(rel.getInput)
+    var input = transformRelation(rel.getInput)
 
     val groupingExprs = 
rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
     val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq
-      .map(expr => transformExpressionWithTypedReduceExpression(expr, input))
+      .map(expr =>
+        expr.getExprTypeCase match {
+          case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION
+              if expr.getUnresolvedFunction.getFunctionName == "reduce" =>
+            // The reduce func needs resolved input data attributes, thus 
handle it specially here
+            input = session.sessionState.executePlan(input).analyzed
+            val udf =
+              
expr.getUnresolvedFunction.getArgumentsList.asScala.map(transformExpression) 
match {
+                case Seq(f: ScalaUDF) =>
+                  TypedScalaUdf(f)
+                case other =>
+                  throw InvalidPlanInput(
+                    s"reduce should carry a scalar scala udf, but got $other")
+              }
+            val tEncoder = udf.outEnc // (T, T) => T
+            val reduce = ReduceAggregator(udf.function)(tEncoder).toColumn.expr
+            TypedAggUtils.withInputType(
+              reduce,
+              tEncoder,
+              input.output.filterNot(p => p.name.startsWith("key_")))

Review Comment:
   The reduce function requires us to filter the key columns out of the input. How do 
we do the same for other agg expressions that use "*" too?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to