GitHub user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/19872#discussion_r156037616
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
---
@@ -48,9 +48,26 @@ object ExtractPythonUDFFromAggregate extends
Rule[LogicalPlan] {
}.isDefined
}
+ private def isPandasGroupAggUdf(expr: Expression): Boolean = expr match {
+ case _ @ PythonUDF(_, _, _, _,
PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF ) => true
+ case Alias(expr, _) => isPandasGroupAggUdf(expr)
+ case _ => false
+ }
+
+ private def hasPandasGroupAggUdf(agg: Aggregate): Boolean = {
+ val actualAggExpr =
agg.aggregateExpressions.drop(agg.groupingExpressions.length)
+ actualAggExpr.exists(isPandasGroupAggUdf)
+ }
+
+
private def extract(agg: Aggregate): LogicalPlan = {
val projList = new ArrayBuffer[NamedExpression]()
val aggExpr = new ArrayBuffer[NamedExpression]()
+
+ if (hasPandasGroupAggUdf(agg)) {
+ Aggregate(agg.groupingExpressions, agg.aggregateExpressions,
agg.child)
+ } else {
+
--- End diff --
nit (style): the body of this new `else` block needs to be indented one level.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]