Github user mallman commented on a diff in the pull request:
https://github.com/apache/spark/pull/16578#discussion_r148716194
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdown.scala ---
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+
+/**
+ * Pushes down aliases to [[expressions.GetStructField]] expressions in an aggregate's grouping
+ * and aggregate expressions into a projection over its children. The original
+ * [[expressions.GetStructField]] expressions are replaced with references to the pushed down
+ * aliases.
+ */
+object AggregateFieldExtractionPushdown extends FieldExtractionPushdown {
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ plan transformDown {
+      case agg @ Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ val expressions = groupingExpressions ++ aggregateExpressions
+        val attributes = AttributeSet(expressions.collect { case att: Attribute => att })
+ val childAttributes = AttributeSet(child.expressions)
+ val fieldExtractors0 =
+ expressions
+ .flatMap(getFieldExtractors)
+ .distinct
+ val fieldExtractors1 =
+ fieldExtractors0
+ .filter(_.collectFirst { case att: Attribute => att }
+ .filter(attributes.contains).isEmpty)
--- End diff --
The attribute `a` is not an element of `expressions`, so it is not in `attributes`. When we construct `attributes`, we simply collect the expressions that are themselves instances of `Attribute`; we don't recurse into their children.
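For illustration only (this snippet is not from the PR, and the names `a`, `extractor` and `expressions` are invented), here is a minimal sketch of why the top-level collect misses an attribute that only appears inside a field extractor:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    // A struct-typed attribute and a field extractor over it.
    val a = AttributeReference("a", StructType(Seq(StructField("x", IntegerType))))()
    val extractor = GetStructField(a, 0, Some("x"))

    // The grouping/aggregate expressions reference `a` only inside the extractor.
    val expressions: Seq[Expression] = Seq(Alias(extractor, "x")())

    // Mirrors the rule's construction of `attributes`: Seq.collect matches only
    // the top-level elements of the sequence; it does not traverse into each
    // expression tree.
    val attributes = AttributeSet(expressions.collect { case att: Attribute => att })

    assert(!attributes.contains(a))  // `a` was never collected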
Your query is exercised by the "basic aggregate field extraction pushdown" test in `AggregateFieldExtractionPushdownSuite`. It's a little hard to spot because I'm using the Catalyst DataFrame DSL, which seems to be the convention in these test suites.
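In case it helps to read the suite, here is a rough sketch (not the actual test code; the relation and column names are invented) of what a query like `SELECT name.first FROM r GROUP BY name.first` looks like in the Catalyst DSL:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // A relation with a nested struct column.
    val relation = LocalRelation('id.int, 'name.struct('first.string, 'last.string))

    // SELECT name.first FROM relation GROUP BY name.first
    val query = relation
      .groupBy('name.getField("first"))('name.getField("first"))
      .analyze

    // The rule under review would then rewrite `query`, aliasing the
    // GetStructField in a Project below the Aggregate:
    // val optimized = AggregateFieldExtractionPushdown(query)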