This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 2efa2e657 [GLUTEN-5649][VL] Fix NullPointerException when collect_list
/ collect_set are partially fallen back (#5655)
2efa2e657 is described below
commit 2efa2e657f2b0cdd43adc9e57963fa17bd7bf5b6
Author: Hongze Zhang <[email protected]>
AuthorDate: Sat May 11 16:16:48 2024 +0800
[GLUTEN-5649][VL] Fix NullPointerException when collect_list / collect_set
are partially fallen back (#5655)
Fixes #5649. Added vanilla implementation of velox_collect_list and
velox_collect_set.
The Velox backend's collect_list / collect_set implementations require ARRAY
intermediate data, whereas Spark uses BINARY. To address this, we previously
used some tricks to forcibly modify the physical plan, changing the output
schema of the partial aggregate operator to align with Velox. With that
approach, however, the actual information about the two functions in the Velox
backend remains hidden from the query plan, which makes advanced optimizations
and compatibility checks difficult during the planning phase.
This patch adds new functions velox_collect_list / velox_collect_set that map
correctly to the Velox backend's implementations of the two functions, and
performs essential code cleanup and refactoring.
---
.../clickhouse/CHSparkPlanExecApi.scala | 8 +-
.../gluten/backendsapi/velox/VeloxBackend.scala | 8 --
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 13 +-
.../execution/HashAggregateExecTransformer.scala | 29 ++---
.../expression}/aggregate/HLLAdapter.scala | 5 +-
.../gluten/expression/aggregate/VeloxCollect.scala | 70 +++++++++++
.../BloomFilterMightContainJointRewriteRule.scala | 3 +-
.../gluten/extension/CollectRewriteRule.scala | 106 ++++++++++++++++
.../extension}/FlushableHashAggregateRule.scala | 2 +-
.../extension/HLLRewriteRule.scala} | 48 +++----
.../gluten/utils/VeloxIntermediateData.scala | 3 -
.../apache/gluten/execution/FallbackSuite.scala | 16 ++-
.../execution/VeloxAggregateFunctionsSuite.scala | 79 ++++++++++++
.../execution/VeloxWindowExpressionSuite.scala | 20 ++-
cpp/velox/substrait/VeloxSubstraitSignature.cc | 1 +
.../gluten/backendsapi/BackendSettingsApi.scala | 4 -
.../gluten/expression/ExpressionMappings.scala | 2 -
.../extension/columnar/TransformHintRule.scala | 1 +
.../extension/columnar/enumerated/RasOffload.scala | 1 +
.../columnar/rewrite/RewriteCollect.scala | 140 ---------------------
.../{ => columnar/rewrite}/RewriteIn.scala | 4 +-
.../columnar/rewrite/RewriteSingleNode.scala | 10 +-
.../rewrite/RewriteTypedImperativeAggregate.scala | 72 -----------
.../columnar/validator/FallbackInjects.scala | 38 ++++++
.../extension/columnar/validator/Validators.scala | 14 +++
.../apache/gluten/utils/BackendTestSettings.scala | 82 +++++-------
.../gluten/utils/velox/VeloxTestSettings.scala | 7 +-
.../gluten/utils/velox/VeloxTestSettings.scala | 7 +-
.../gluten/utils/velox/VeloxTestSettings.scala | 7 +-
.../spark/sql/GlutenDataFrameAggregateSuite.scala | 79 ++++++++++--
.../gluten/utils/velox/VeloxTestSettings.scala | 7 +-
31 files changed, 522 insertions(+), 364 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 030648b06..465041621 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -40,7 +40,7 @@ import
org.apache.spark.sql.catalyst.{CHAggregateFunctionRewriteRule, EqualToRew
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
CollectList, CollectSet}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -631,7 +631,11 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
/** Define backend specfic expression mappings. */
override def extraExpressionMappings: Seq[Sig] = {
- SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
+ List(
+ Sig[CollectList](ExpressionNames.COLLECT_LIST),
+ Sig[CollectSet](ExpressionNames.COLLECT_SET)
+ ) ++
+ SparkShimLoader.getSparkShims.bloomFilterExpressionMappings()
}
override def genStringTranslateTransformer(
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index aad8ff5d5..5509d37e8 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -501,14 +501,6 @@ object VeloxBackendSettings extends BackendSettingsApi {
override def supportBroadcastNestedLoopJoinExec(): Boolean = true
- override def shouldRewriteTypedImperativeAggregate(): Boolean = {
- // The intermediate type of collect_list, collect_set in Velox backend is
not consistent with
- // vanilla Spark, we need to rewrite the aggregate to get the correct data
type.
- true
- }
-
- override def shouldRewriteCollect(): Boolean = true
-
override def supportColumnarArrowUdf(): Boolean = true
override def generateHdfsConfForLibhdfs(): Boolean = true
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index f98055630..772f1cfb2 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -23,8 +23,8 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
-import org.apache.gluten.extension.BloomFilterMightContainJointRewriteRule
+import org.apache.gluten.expression.aggregate.{HLLAdapter,
VeloxBloomFilterAggregate, VeloxCollectList, VeloxCollectSet}
+import org.apache.gluten.extension.{BloomFilterMightContainJointRewriteRule,
CollectRewriteRule, FlushableHashAggregateRule, HLLRewriteRule}
import org.apache.gluten.extension.columnar.TransformHints
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.expression.{ExpressionBuilder,
ExpressionNode, IfThenNode}
@@ -37,12 +37,12 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{GenShuffleWriterParameters,
GlutenShuffleWriterWrapper}
import org.apache.spark.shuffle.utils.ShuffleUtil
import org.apache.spark.sql.{SparkSession, Strategy}
-import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule,
FlushableHashAggregateRule, FunctionIdentifier}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
HLLAdapter}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -734,7 +734,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedOptimizers(): List[SparkSession =>
Rule[LogicalPlan]] = List(
- AggregateFunctionRewriteRule.apply
+ CollectRewriteRule.apply,
+ HLLRewriteRule.apply
)
/**
@@ -788,6 +789,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
Sig[UDFExpression](ExpressionNames.UDF_PLACEHOLDER),
Sig[UserDefinedAggregateFunction](ExpressionNames.UDF_PLACEHOLDER),
Sig[NaNvl](ExpressionNames.NANVL),
+ Sig[VeloxCollectList](ExpressionNames.COLLECT_LIST),
+ Sig[VeloxCollectSet](ExpressionNames.COLLECT_SET),
Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG)
)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index f0a7ea180..26d30606d 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression._
import org.apache.gluten.expression.ConverterUtils.FunctionConfig
-import
org.apache.gluten.extension.columnar.rewrite.RewriteTypedImperativeAggregate
+import org.apache.gluten.expression.aggregate.HLLAdapter
import org.apache.gluten.substrait.`type`.{TypeBuilder, TypeNode}
import org.apache.gluten.substrait.{AggregationParams, SubstraitContext}
import org.apache.gluten.substrait.expression.{AggregateFunctionNode,
ExpressionBuilder, ExpressionNode, ScalarFunctionNode}
@@ -807,25 +807,14 @@ case class HashAggregateExecPullOutHelper(
override protected def getAttrForAggregateExprs: List[Attribute] = {
aggregateExpressions.zipWithIndex.flatMap {
case (expr, index) =>
- handleSpecialAggregateAttr
- .lift(expr)
- .getOrElse(expr.mode match {
- case Partial | PartialMerge =>
- expr.aggregateFunction.aggBufferAttributes
- case Final =>
- Seq(aggregateAttributes(index))
- case other =>
- throw new GlutenNotSupportException(s"Unsupported aggregate
mode: $other.")
- })
+ expr.mode match {
+ case Partial | PartialMerge =>
+ expr.aggregateFunction.aggBufferAttributes
+ case Final =>
+ Seq(aggregateAttributes(index))
+ case other =>
+ throw new GlutenNotSupportException(s"Unsupported aggregate mode:
$other.")
+ }
}.toList
}
-
- private val handleSpecialAggregateAttr: PartialFunction[AggregateExpression,
Seq[Attribute]] = {
- case ae: AggregateExpression if
RewriteTypedImperativeAggregate.shouldRewrite(ae) =>
- val aggBufferAttr = ae.aggregateFunction.inputAggBufferAttributes.head
- Seq(
- aggBufferAttr.copy(dataType = ae.aggregateFunction.dataType)(
- aggBufferAttr.exprId,
- aggBufferAttr.qualifier))
- }
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HLLAdapter.scala
b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/HLLAdapter.scala
similarity index 94%
rename from
backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HLLAdapter.scala
rename to
backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/HLLAdapter.scala
index 05e0f8441..78b4cb148 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HLLAdapter.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/HLLAdapter.scala
@@ -14,10 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions.aggregate
+package org.apache.gluten.expression.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{HyperLogLogPlusPlus,
ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.HyperLogLogPlusPlusHelper
import org.apache.spark.sql.types._
@@ -52,7 +53,7 @@ case class HLLAdapter(
private lazy val row = new UnsafeRow(hllppHelper.numWords)
- override def prettyName: String = "approx_count_distinct_velox"
+ override def prettyName: String = "velox_approx_count_distinct"
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int):
ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala
b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala
new file mode 100644
index 000000000..c12aeab26
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.{ArrayDistinct,
AttributeReference, Concat, CreateArray, Expression, If, IsNull, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.types.{ArrayType, DataType}
+
+abstract class VeloxCollect extends DeclarativeAggregate with
UnaryLike[Expression] {
+ protected lazy val buffer: AttributeReference = AttributeReference("buffer",
dataType)()
+
+ override def dataType: DataType = ArrayType(child.dataType, false)
+
+ override def aggBufferAttributes: Seq[AttributeReference] = List(buffer)
+
+ override lazy val initialValues: Seq[Expression] =
List(Literal.create(Seq.empty, dataType))
+
+ override lazy val updateExpressions: Seq[Expression] = List(
+ If(
+ IsNull(child),
+ buffer,
+ Concat(List(buffer, CreateArray(List(child), useStringTypeWhenEmpty =
false))))
+ )
+
+ override lazy val mergeExpressions: Seq[Expression] = List(
+ Concat(List(buffer.left, buffer.right))
+ )
+
+ override def defaultResult: Option[Literal] = Option(Literal.create(Array(),
dataType))
+}
+
+case class VeloxCollectSet(override val child: Expression) extends
VeloxCollect {
+ override def prettyName: String = "velox_collect_set"
+
+ // Velox's collect_set implementation allows null output. Thus we usually
wrap
+ // the function to enforce non-null output. See
CollectRewriteRule#ensureNonNull.
+ override def nullable: Boolean = true
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression =
+ copy(child = newChild)
+
+ override lazy val evaluateExpression: Expression =
+ ArrayDistinct(buffer)
+}
+
+case class VeloxCollectList(override val child: Expression) extends
VeloxCollect {
+ override def prettyName: String = "velox_collect_list"
+
+ override def nullable: Boolean = false
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression =
+ copy(child = newChild)
+
+ override val evaluateExpression: Expression = buffer
+}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
index deba381db..9a0a59e8e 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/extension/BloomFilterMightContainJointRewriteRule.scala
@@ -20,13 +20,14 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.expression.VeloxBloomFilterMightContain
import org.apache.gluten.expression.aggregate.VeloxBloomFilterAggregate
import org.apache.gluten.sql.shims.SparkShimLoader
+import org.apache.gluten.utils.PhysicalPlanSelector
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
case class BloomFilterMightContainJointRewriteRule(spark: SparkSession)
extends Rule[SparkPlan] {
- override def apply(plan: SparkPlan): SparkPlan = {
+ override def apply(plan: SparkPlan): SparkPlan =
PhysicalPlanSelector.maybe(spark, plan) {
if (!(GlutenConfig.getConf.enableNativeBloomFilter)) {
return plan
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala
b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala
new file mode 100644
index 000000000..d7299c511
--- /dev/null
+++
b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.expression.ExpressionMappings
+import org.apache.gluten.expression.aggregate.{VeloxCollectList,
VeloxCollectSet}
+import org.apache.gluten.utils.LogicalPlanSelector
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{And, Coalesce, Expression,
IsNotNull, Literal, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan,
Window}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.ArrayType
+
+import scala.reflect.{classTag, ClassTag}
+
+/**
+ * Velox's collect_list / collect_set use array as intermediate data type so
aren't compatible with
+ * vanilla Spark. We here replace the two functions with velox_collect_list /
velox_collect_set to
+ * distinguish.
+ */
+case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+ import CollectRewriteRule._
+ override def apply(plan: LogicalPlan): LogicalPlan =
LogicalPlanSelector.maybe(spark, plan) {
+ val out = plan.transformUp {
+ case node =>
+ val out = replaceCollectSet(replaceCollectList(node))
+ out
+ }
+ if (out.fastEquals(plan)) {
+ return plan
+ }
+ out
+ }
+
+ private def replaceCollectList(node: LogicalPlan): LogicalPlan = {
+ node.transformExpressions {
+ case func @ AggregateExpression(l: CollectList, _, _, _, _) if
has[VeloxCollectList] =>
+ func.copy(VeloxCollectList(l.child))
+ }
+ }
+
+ private def replaceCollectSet(node: LogicalPlan): LogicalPlan = {
+ // 1. Replace null result from VeloxCollectSet with empty array to align
with
+ // vanilla Spark.
+ // 2. Filter out null inputs from VeloxCollectSet to align with vanilla
Spark.
+ //
+ // Since https://github.com/apache/incubator-gluten/pull/4805
+ node match {
+ case agg: Aggregate =>
+ agg.transformExpressions {
+ case ToVeloxCollectSet(newAggFunc) =>
+ val out = ensureNonNull(newAggFunc)
+ out
+ }
+ case w: Window =>
+ w.transformExpressions {
+ case func @ WindowExpression(ToVeloxCollectSet(newAggFunc), _) =>
+ val out = ensureNonNull(func.copy(newAggFunc))
+ out
+ }
+ case other => other
+ }
+ }
+}
+
+object CollectRewriteRule {
+ private def ensureNonNull(expr: Expression): Expression = {
+ val out =
+ Coalesce(List(expr, Literal.create(Seq.empty, expr.dataType)))
+ assert(!out.nullable)
+ assert(!out.dataType.asInstanceOf[ArrayType].containsNull)
+ out
+ }
+
+ private object ToVeloxCollectSet {
+ def unapply(expr: Expression): Option[Expression] = expr match {
+ case aggFunc @ AggregateExpression(s: CollectSet, _, _, filter, _) if
has[VeloxCollectSet] =>
+ val newFilter = (filter ++ Some(IsNotNull(s.child))).reduceOption(And)
+ val newAggFunc =
+ aggFunc.copy(aggregateFunction = VeloxCollectSet(s.child), filter =
newFilter)
+ Some(newAggFunc)
+ case _ => None
+ }
+ }
+
+ private def has[T <: Expression: ClassTag]: Boolean = {
+ val out =
ExpressionMappings.expressionsMap.contains(classTag[T].runtimeClass)
+ out
+ }
+}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala
b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
similarity index 99%
rename from
backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala
rename to
backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
index a1858e599..f850b6f45 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/FlushableHashAggregateRule.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst
+package org.apache.gluten.extension
import org.apache.gluten.execution.{FlushableHashAggregateExecTransformer,
HashAggregateExecTransformer, ProjectExecTransformer,
RegularHashAggregateExecTransformer}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/AggregateFunctionRewriteRule.scala
b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
similarity index 65%
rename from
backends-velox/src/main/scala/org/apache/spark/sql/catalyst/AggregateFunctionRewriteRule.scala
rename to
backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
index c5cb1b24d..cb1e626a1 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/catalyst/AggregateFunctionRewriteRule.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/extension/HLLRewriteRule.scala
@@ -14,37 +14,41 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst
+package org.apache.gluten.extension
import org.apache.gluten.GlutenConfig
+import org.apache.gluten.expression.aggregate.HLLAdapter
+import org.apache.gluten.utils.LogicalPlanSelector
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Literal
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
HLLAdapter, HyperLogLogPlusPlus}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
HyperLogLogPlusPlus}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
-case class AggregateFunctionRewriteRule(spark: SparkSession) extends
Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp
{
- case a: Aggregate =>
- a.transformExpressions {
- case hllExpr @ AggregateExpression(hll: HyperLogLogPlusPlus, _, _, _,
_)
- if GlutenConfig.getConf.enableNativeHyperLogLogAggregateFunction &&
- GlutenConfig.getConf.enableColumnarHashAgg &&
- !hasDistinctAggregateFunc(a) &&
isDataTypeSupported(hll.child.dataType) =>
- AggregateExpression(
- HLLAdapter(
- hll.child,
- Literal(hll.relativeSD),
- hll.mutableAggBufferOffset,
- hll.inputAggBufferOffset),
- hllExpr.mode,
- hllExpr.isDistinct,
- hllExpr.filter,
- hllExpr.resultId
- )
- }
+case class HLLRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan =
LogicalPlanSelector.maybe(spark, plan) {
+ plan.resolveOperatorsUp {
+ case a: Aggregate =>
+ a.transformExpressions {
+ case hllExpr @ AggregateExpression(hll: HyperLogLogPlusPlus, _, _,
_, _)
+ if GlutenConfig.getConf.enableNativeHyperLogLogAggregateFunction
&&
+ GlutenConfig.getConf.enableColumnarHashAgg &&
+ !hasDistinctAggregateFunc(a) &&
isDataTypeSupported(hll.child.dataType) =>
+ AggregateExpression(
+ HLLAdapter(
+ hll.child,
+ Literal(hll.relativeSD),
+ hll.mutableAggBufferOffset,
+ hll.inputAggBufferOffset),
+ hllExpr.mode,
+ hllExpr.isDistinct,
+ hllExpr.filter,
+ hllExpr.resultId
+ )
+ }
+ }
}
private def hasDistinctAggregateFunc(agg: Aggregate): Boolean = {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
index a00bcae1c..a22655152 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala
@@ -125,9 +125,6 @@ object VeloxIntermediateData {
aggregateFunc match {
case _ @Type(veloxDataTypes: Seq[DataType]) =>
Seq(StructType(veloxDataTypes.map(StructField("", _)).toArray))
- case _: CollectList | _: CollectSet =>
- // CollectList and CollectSet should use data type of agg function.
- Seq(aggregateFunc.dataType)
case _ =>
// Not use StructType for single column agg intermediate data
aggregateFunc.aggBufferAttributes.map(_.dataType)
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
index fbad525a2..e8833a43c 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala
@@ -51,6 +51,12 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite
with AdaptiveSparkPl
.write
.format("parquet")
.saveAsTable("tmp2")
+ spark
+ .range(100)
+ .selectExpr("cast(id % 3 as int) as c1", "cast(id % 9 as int) as c2")
+ .write
+ .format("parquet")
+ .saveAsTable("tmp3")
}
override protected def afterAll(): Unit = {
@@ -106,15 +112,14 @@ class FallbackSuite extends
VeloxWholeStageTransformerSuite with AdaptiveSparkPl
}
}
- // java.lang.NullPointerException
- ignore("fallback final aggregate of collect_list") {
+ test("fallback final aggregate of collect_list") {
withSQLConf(
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1",
GlutenConfig.COLUMNAR_FALLBACK_IGNORE_ROW_TO_COLUMNAR.key -> "false",
GlutenConfig.EXPRESSION_BLACK_LIST.key -> "element_at"
) {
runQueryAndCompare(
- "SELECT sum(ele) FROM (SELECT c1, element_at(collect_list(c2), 1) as
ele FROM tmp1 " +
+ "SELECT sum(ele) FROM (SELECT c1, element_at(collect_list(c2), 1) as
ele FROM tmp3 " +
"GROUP BY c1)") {
df =>
val columnarToRow =
collectColumnarToRow(df.queryExecution.executedPlan)
@@ -123,7 +128,8 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite
with AdaptiveSparkPl
}
}
- // java.lang.NullPointerException
+ // Elements in velox_collect_set's output set may be in different order.
This is a benign bug
+ // until we can exactly align with vanilla Spark.
ignore("fallback final aggregate of collect_set") {
withSQLConf(
GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "1",
@@ -131,7 +137,7 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite
with AdaptiveSparkPl
GlutenConfig.EXPRESSION_BLACK_LIST.key -> "element_at"
) {
runQueryAndCompare(
- "SELECT sum(ele) FROM (SELECT c1, element_at(collect_set(c2), 1) as
ele FROM tmp1 " +
+ "SELECT sum(ele) FROM (SELECT c1, element_at(collect_set(c2), 1) as
ele FROM tmp3 " +
"GROUP BY c1)") {
df =>
val columnarToRow =
collectColumnarToRow(df.queryExecution.executedPlan)
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
index 70fff52b8..398f5e05e 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala
@@ -17,8 +17,11 @@
package org.apache.gluten.execution
import org.apache.gluten.GlutenConfig
+import org.apache.gluten.extension.columnar.validator.FallbackInjects
import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.internal.SQLConf
abstract class VeloxAggregateFunctionsSuite extends
VeloxWholeStageTransformerSuite {
@@ -977,6 +980,82 @@ abstract class VeloxAggregateFunctionsSuite extends
VeloxWholeStageTransformerSu
}
}
+ // Used for testing aggregate fallback
+ sealed trait FallbackMode
+ case object Offload extends FallbackMode
+ case object FallbackPartial extends FallbackMode
+ case object FallbackFinal extends FallbackMode
+ case object FallbackAll extends FallbackMode
+
+ List(Offload, FallbackPartial, FallbackFinal, FallbackAll).foreach {
+ mode =>
+ test(s"test fallback collect_set/collect_list with null, $mode") {
+ mode match {
+ case Offload => doTest()
+ case FallbackPartial =>
+ FallbackInjects.fallbackOn {
+ case agg: BaseAggregateExec =>
+ agg.aggregateExpressions.exists(_.mode == Partial)
+ } {
+ doTest()
+ }
+ case FallbackFinal =>
+ FallbackInjects.fallbackOn {
+ case agg: BaseAggregateExec =>
+ agg.aggregateExpressions.exists(_.mode == Final)
+ } {
+ doTest()
+ }
+ case FallbackAll =>
+ FallbackInjects.fallbackOn { case _: BaseAggregateExec => true } {
+ doTest()
+ }
+ }
+
+ def doTest(): Unit = {
+ withTempView("collect_tmp") {
+ Seq((1, null), (1, "a"), (2, null), (3, null), (3, null), (4, "b"))
+ .toDF("c1", "c2")
+ .createOrReplaceTempView("collect_tmp")
+
+ // basic test
+ runQueryAndCompare(
+ "SELECT collect_set(c2), collect_list(c2) FROM collect_tmp GROUP
BY c1") { _ => }
+
+ // test pre project and post project
+ runQueryAndCompare("""
+ |SELECT
+ |size(collect_set(if(c2 = 'a', 'x', 'y'))) as
x,
+ |size(collect_list(if(c2 = 'a', 'x', 'y')))
as y
+ |FROM collect_tmp GROUP BY c1
+ |""".stripMargin) { _ => }
+
+ // test distinct
+ runQueryAndCompare(
+ "SELECT collect_set(c2), collect_list(distinct c2) FROM
collect_tmp GROUP BY c1") {
+ _ =>
+ }
+
+ // test distinct + pre project and post project
+ runQueryAndCompare("""
+ |SELECT
+ |size(collect_set(if(c2 = 'a', 'x', 'y'))),
+ |size(collect_list(distinct if(c2 = 'a', 'x',
'y')))
+ |FROM collect_tmp GROUP BY c1
+ |""".stripMargin) { _ => }
+
+ // test cast array to string
+ runQueryAndCompare("""
+ |SELECT
+ |cast(collect_set(c2) as string),
+ |cast(collect_list(c2) as string)
+ |FROM collect_tmp GROUP BY c1
+ |""".stripMargin) { _ => }
+ }
+ }
+ }
+ }
+
test("count(1)") {
runQueryAndCompare(
"""
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxWindowExpressionSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxWindowExpressionSuite.scala
index 3dfbd6bd2..03b295f49 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxWindowExpressionSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxWindowExpressionSuite.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.execution
import org.apache.spark.SparkConf
import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.types._
class VeloxWindowExpressionSuite extends WholeStageTransformerSuite {
@@ -72,7 +73,7 @@ class VeloxWindowExpressionSuite extends
WholeStageTransformerSuite {
}
}
- test("collect_list") {
+ test("collect_list / collect_set") {
withTable("t") {
val data = Seq(
Row(0, 1),
@@ -108,6 +109,23 @@ class VeloxWindowExpressionSuite extends
WholeStageTransformerSuite {
|""".stripMargin) {
checkGlutenOperatorMatch[WindowExecTransformer]
}
+
+ runQueryAndCompare(
+ """
+ |SELECT
+ | c1,
+ | collect_set(c2) OVER (
+ | PARTITION BY c1
+ | )
+ |FROM
+ | t
+ |ORDER BY 1, 2;
+ |""".stripMargin,
+ noFallBack = false
+ ) {
+ // Velox window doesn't support collect_set
+ checkSparkOperatorMatch[WindowExec]
+ }
}
}
}
diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.cc
b/cpp/velox/substrait/VeloxSubstraitSignature.cc
index ee7c5f513..fa415cfef 100644
--- a/cpp/velox/substrait/VeloxSubstraitSignature.cc
+++ b/cpp/velox/substrait/VeloxSubstraitSignature.cc
@@ -188,6 +188,7 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa
types.emplace_back(fromSubstraitSignature(typeStr));
break;
}
+ VELOX_CHECK(childrenTypes.at(typeEnd) == delimiter)
std::string typeStr = childrenTypes.substr(typeStart, typeEnd - typeStart);
types.emplace_back(fromSubstraitSignature(typeStr));
typeStart = typeEnd + 1;
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index ac8c2a436..c8729561d 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -149,10 +149,6 @@ trait BackendSettingsApi {
/** Merge two phases hash based aggregate if need */
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false
- def shouldRewriteTypedImperativeAggregate(): Boolean = false
-
- def shouldRewriteCollect(): Boolean = false
-
def supportColumnarArrowUdf(): Boolean = false
def generateHdfsConfForLibhdfs(): Boolean = false
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index 33e4f0a7b..459a3d8e2 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -288,8 +288,6 @@ object ExpressionMappings {
Sig[MinBy](MIN_BY),
Sig[StddevSamp](STDDEV_SAMP),
Sig[StddevPop](STDDEV_POP),
- Sig[CollectList](COLLECT_LIST),
- Sig[CollectSet](COLLECT_SET),
Sig[VarianceSamp](VAR_SAMP),
Sig[VariancePop](VAR_POP),
Sig[BitAndAgg](BIT_AND_AGG),
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
index 3c3d23ccc..c9fcc52aa 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala
@@ -304,6 +304,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
.fallbackComplexExpressions()
.fallbackByBackendSettings()
.fallbackByUserOptions()
+ .fallbackByTestInjects()
.build()
def apply(plan: SparkPlan): SparkPlan = {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
index 57e093bde..5cabfa88e 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -78,6 +78,7 @@ object RasOffload {
.fallbackComplexExpressions()
.fallbackByBackendSettings()
.fallbackByUserOptions()
+ .fallbackByTestInjects()
.build()
private val rewrites = RewriteSingleNode.allRules()
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
deleted file mode 100644
index 74d493de5..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteCollect.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension.columnar.rewrite
-
-import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.utils.PullOutProjectHelper
-
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute,
AttributeSet, If, IsNotNull, IsNull, Literal, NamedExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-import org.apache.spark.sql.types.ArrayType
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * This rule rewrite collect_set and collect_list to be compatible with
vanilla Spark.
- *
- * - Add `IsNotNull(partial_in)` to skip null value before going to native
collect_set
- * - Add `If(IsNull(result), CreateArray(Seq.empty), result)` to replace
null to empty array
- *
- * TODO: remove this rule once Velox compatible with vanilla Spark.
- */
-object RewriteCollect extends RewriteSingleNode with PullOutProjectHelper {
- private lazy val shouldRewriteCollect =
- BackendsApiManager.getSettings.shouldRewriteCollect()
-
- private def shouldAddIsNotNull(ae: AggregateExpression): Boolean = {
- ae.aggregateFunction match {
- case c: CollectSet if c.child.nullable =>
- ae.mode match {
- case Partial | Complete => true
- case _ => false
- }
- case _ => false
- }
- }
-
- private def shouldReplaceNullToEmptyArray(ae: AggregateExpression): Boolean
= {
- ae.aggregateFunction match {
- case _: CollectSet =>
- ae.mode match {
- case Final | Complete => true
- case _ => false
- }
- case _ => false
- }
- }
-
- private def shouldRewrite(agg: BaseAggregateExec): Boolean = {
- agg.aggregateExpressions.exists {
- ae => shouldAddIsNotNull(ae) || shouldReplaceNullToEmptyArray(ae)
- }
- }
-
- private def rewriteCollectFilter(aggExprs: Seq[AggregateExpression]):
Seq[AggregateExpression] = {
- aggExprs
- .map {
- aggExpr =>
- if (shouldAddIsNotNull(aggExpr)) {
- val newFilter =
- (aggExpr.filter ++
Seq(IsNotNull(aggExpr.aggregateFunction.children.head)))
- .reduce(And)
- aggExpr.copy(filter = Option(newFilter))
- } else {
- aggExpr
- }
- }
- }
-
- private def rewriteAttributesAndResultExpressions(
- agg: BaseAggregateExec): (Seq[Attribute], Seq[NamedExpression]) = {
- val rewriteAggExprIndices = agg.aggregateExpressions.zipWithIndex
- .filter(exprAndIndex => shouldReplaceNullToEmptyArray(exprAndIndex._1))
- .map(_._2)
- .toSet
- if (rewriteAggExprIndices.isEmpty) {
- return (agg.aggregateAttributes, agg.resultExpressions)
- }
-
- assert(agg.aggregateExpressions.size == agg.aggregateAttributes.size)
- val rewriteAggAttributes = new ArrayBuffer[Attribute]()
- val newAggregateAttributes = agg.aggregateAttributes.zipWithIndex.map {
- case (attr, index) =>
- if (rewriteAggExprIndices.contains(index)) {
- rewriteAggAttributes.append(attr)
- // We should mark attribute as withNullability since the collect_set
and collect_set
- // are not nullable but velox may return null. This is to avoid
potential issue when
- // the post project fallback to vanilla Spark.
- attr.withNullability(true)
- } else {
- attr
- }
- }
- val rewriteAggAttributeSet = AttributeSet(rewriteAggAttributes)
- val newResultExpressions = agg.resultExpressions.map {
- ne =>
- val rewritten = ne.transformUp {
- case attr: Attribute if rewriteAggAttributeSet.contains(attr) =>
- assert(attr.dataType.isInstanceOf[ArrayType])
- If(IsNull(attr), Literal.create(Seq.empty, attr.dataType), attr)
- }
- assert(rewritten.isInstanceOf[NamedExpression])
- rewritten.asInstanceOf[NamedExpression]
- }
- (newAggregateAttributes, newResultExpressions)
- }
-
- override def rewrite(plan: SparkPlan): SparkPlan = {
- if (!shouldRewriteCollect) {
- return plan
- }
-
- plan match {
- case agg: BaseAggregateExec if shouldRewrite(agg) =>
- val newAggExprs = rewriteCollectFilter(agg.aggregateExpressions)
- val (newAttributes, newResultExprs) =
rewriteAttributesAndResultExpressions(agg)
- copyBaseAggregateExec(agg)(
- newAggregateExpressions = newAggExprs,
- newAggregateAttributes = newAttributes,
- newResultExpressions = newResultExprs)
-
- case _ => plan
- }
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteIn.scala
similarity index 96%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteIn.scala
index 565b9bb19..da120c39a 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/RewriteIn.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteIn.scala
@@ -14,9 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension
-
-import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
+package org.apache.gluten.extension.columnar.rewrite
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, In, Or}
import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec,
SparkPlan}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
index 73bc8b967..01f2e29fe 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteSingleNode.scala
@@ -16,8 +16,6 @@
*/
package org.apache.gluten.extension.columnar.rewrite
-import org.apache.gluten.extension.RewriteIn
-
import org.apache.spark.sql.execution.SparkPlan
/**
@@ -37,12 +35,6 @@ trait RewriteSingleNode {
object RewriteSingleNode {
def allRules(): Seq[RewriteSingleNode] = {
- Seq(
- RewriteIn,
- RewriteMultiChildrenCount,
- RewriteCollect,
- RewriteTypedImperativeAggregate,
- PullOutPreProject,
- PullOutPostProject)
+ Seq(RewriteIn, RewriteMultiChildrenCount, PullOutPreProject,
PullOutPostProject)
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
deleted file mode 100644
index 971a87923..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/rewrite/RewriteTypedImperativeAggregate.scala
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.extension.columnar.rewrite
-
-import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.utils.PullOutProjectHelper
-
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
-
-object RewriteTypedImperativeAggregate extends RewriteSingleNode with
PullOutProjectHelper {
- private lazy val shouldRewriteTypedImperativeAggregate =
- BackendsApiManager.getSettings.shouldRewriteTypedImperativeAggregate()
-
- def shouldRewrite(ae: AggregateExpression): Boolean = {
- ae.aggregateFunction match {
- case _: CollectList | _: CollectSet =>
- ae.mode match {
- case Partial | PartialMerge => true
- case _ => false
- }
- case _ => false
- }
- }
-
- override def rewrite(plan: SparkPlan): SparkPlan = {
- if (!shouldRewriteTypedImperativeAggregate) {
- return plan
- }
-
- plan match {
- case agg: BaseAggregateExec if
agg.aggregateExpressions.exists(shouldRewrite) =>
- val exprMap = agg.aggregateExpressions
- .filter(shouldRewrite)
- .map(ae => ae.aggregateFunction.inputAggBufferAttributes.head -> ae)
- .toMap
- val newResultExpressions = agg.resultExpressions.map {
- case attr: AttributeReference =>
- exprMap
- .get(attr)
- .map {
- ae =>
- attr.copy(dataType = ae.aggregateFunction.dataType)(
- exprId = attr.exprId,
- qualifier = attr.qualifier
- )
- }
- .getOrElse(attr)
- case other => other
- }
- copyBaseAggregateExec(agg)(newResultExpressions = newResultExpressions)
-
- case _ => plan
- }
- }
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/FallbackInjects.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/FallbackInjects.scala
new file mode 100644
index 000000000..54139ec9e
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/FallbackInjects.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.validator
+
+import org.apache.spark.sql.execution.SparkPlan
+
+object FallbackInjects {
+ private var fallbackCondition: Option[PartialFunction[SparkPlan, Boolean]] = None
+
+ def fallbackOn[T](condition: PartialFunction[SparkPlan, Boolean])(func: => T): T =
+ synchronized {
+ assert(this.fallbackCondition.isEmpty)
+ this.fallbackCondition = Some(condition)
+ try {
+ func
+ } finally {
+ this.fallbackCondition = None
+ }
+ }
+
+ private[validator] def shouldFallback(node: SparkPlan): Boolean = {
+ fallbackCondition.exists(_.applyOrElse(node, { _: SparkPlan => false }))
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
index 57bcc7e09..d4bd9926a 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
@@ -82,6 +82,11 @@ object Validators {
this
}
+ def fallbackByTestInjects(): Builder = {
+ buffer += new FallbackByTestInjects()
+ this
+ }
+
/** Add a custom validator to pipeline. */
def add(validator: Validator): Builder = {
buffer += validator
@@ -191,6 +196,15 @@ object Validators {
}
}
+ private class FallbackByTestInjects() extends Validator {
+ override def validate(plan: SparkPlan): Validator.OutCome = {
+ if (FallbackInjects.shouldFallback(plan)) {
+ return fail(plan)
+ }
+ pass()
+ }
+ }
+
private class ValidatorPipeline(validators: Seq[Validator]) extends Validator {
override def validate(plan: SparkPlan): Validator.OutCome = {
val init: Validator.OutCome = pass()
diff --git
a/gluten-ut/common/src/test/scala/org/apache/gluten/utils/BackendTestSettings.scala
b/gluten-ut/common/src/test/scala/org/apache/gluten/utils/BackendTestSettings.scala
index fe8d4678d..987635d06 100644
---
a/gluten-ut/common/src/test/scala/org/apache/gluten/utils/BackendTestSettings.scala
+++
b/gluten-ut/common/src/test/scala/org/apache/gluten/utils/BackendTestSettings.scala
@@ -94,14 +94,6 @@ abstract class BackendTestSettings {
exclusion.add(Exclude(testNames: _*))
this
}
- def includeGlutenTest(testName: String*): SuiteSettings = {
- inclusion.add(IncludeGlutenTest(testName: _*))
- this
- }
- def excludeGlutenTest(testName: String*): SuiteSettings = {
- exclusion.add(ExcludeGlutenTest(testName: _*))
- this
- }
def includeByPrefix(prefixes: String*): SuiteSettings = {
inclusion.add(IncludeByPrefix(prefixes: _*))
this
@@ -110,22 +102,6 @@ abstract class BackendTestSettings {
exclusion.add(ExcludeByPrefix(prefixes: _*))
this
}
- def includeGlutenTestsByPrefix(prefixes: String*): SuiteSettings = {
- inclusion.add(IncludeGlutenTestByPrefix(prefixes: _*))
- this
- }
- def excludeGlutenTestsByPrefix(prefixes: String*): SuiteSettings = {
- exclusion.add(ExcludeGlutenTestByPrefix(prefixes: _*))
- this
- }
- def includeAllGlutenTests(): SuiteSettings = {
- inclusion.add(IncludeByPrefix(GLUTEN_TEST))
- this
- }
- def excludeAllGlutenTests(): SuiteSettings = {
- exclusion.add(ExcludeByPrefix(GLUTEN_TEST))
- this
- }
def disable(reason: String): SuiteSettings = {
disableReason = disableReason match {
@@ -136,6 +112,40 @@ abstract class BackendTestSettings {
}
}
+ object SuiteSettings {
+ implicit class SuiteSettingsImplicits(settings: SuiteSettings) {
+ def includeGlutenTest(testName: String*): SuiteSettings = {
+ settings.include(testName.map(GLUTEN_TEST + _): _*)
+ settings
+ }
+
+ def excludeGlutenTest(testName: String*): SuiteSettings = {
+ settings.exclude(testName.map(GLUTEN_TEST + _): _*)
+ settings
+ }
+
+ def includeGlutenTestsByPrefix(prefixes: String*): SuiteSettings = {
+ settings.includeByPrefix(prefixes.map(GLUTEN_TEST + _): _*)
+ settings
+ }
+
+ def excludeGlutenTestsByPrefix(prefixes: String*): SuiteSettings = {
+ settings.excludeByPrefix(prefixes.map(GLUTEN_TEST + _): _*)
+ settings
+ }
+
+ def includeAllGlutenTests(): SuiteSettings = {
+ settings.include(GLUTEN_TEST)
+ settings
+ }
+
+ def excludeAllGlutenTests(): SuiteSettings = {
+ settings.exclude(GLUTEN_TEST)
+ settings
+ }
+ }
+ }
+
protected trait IncludeBase {
def isIncluded(testName: String): Boolean
}
@@ -150,14 +160,6 @@ abstract class BackendTestSettings {
val nameSet: Set[String] = Set(testNames: _*)
override def isExcluded(testName: String): Boolean =
nameSet.contains(testName)
}
- private case class IncludeGlutenTest(testNames: String*) extends IncludeBase
{
- val nameSet: Set[String] = testNames.map(name => GLUTEN_TEST + name).toSet
- override def isIncluded(testName: String): Boolean =
nameSet.contains(testName)
- }
- private case class ExcludeGlutenTest(testNames: String*) extends ExcludeBase
{
- val nameSet: Set[String] = testNames.map(name => GLUTEN_TEST + name).toSet
- override def isExcluded(testName: String): Boolean =
nameSet.contains(testName)
- }
private case class IncludeByPrefix(prefixes: String*) extends IncludeBase {
override def isIncluded(testName: String): Boolean = {
if (prefixes.exists(prefix => testName.startsWith(prefix))) {
@@ -174,22 +176,6 @@ abstract class BackendTestSettings {
false
}
}
- private case class IncludeGlutenTestByPrefix(prefixes: String*) extends
IncludeBase {
- override def isIncluded(testName: String): Boolean = {
- if (prefixes.exists(prefix => testName.startsWith(GLUTEN_TEST +
prefix))) {
- return true
- }
- false
- }
- }
- private case class ExcludeGlutenTestByPrefix(prefixes: String*) extends
ExcludeBase {
- override def isExcluded(testName: String): Boolean = {
- if (prefixes.exists(prefix => testName.startsWith(GLUTEN_TEST +
prefix))) {
- return true
- }
- false
- }
- }
def getSQLQueryTestSettings: SQLQueryTestSettings
}
diff --git
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index a319c5ca9..1207514c2 100644
---
a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -54,7 +54,12 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct
aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
- " before using it"
+ " before using it",
+ // Velox's collect_list / collect_set are by design declarative aggregate so plan check
+ // for ObjectHashAggregateExec will fail.
+ "SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = true)",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = false)"
)
enableSuite[GlutenCastSuite]
diff --git
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index d750456b6..40185aa63 100644
---
a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -953,7 +953,12 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct
aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
- " before using it"
+ " before using it",
+ // Velox's collect_list / collect_set are by design declarative aggregate so plan check
+ // for ObjectHashAggregateExec will fail.
+ "SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = true)",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = false)"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 689eaf39e..47ad21958 100644
---
a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -958,7 +958,12 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct
aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
- " before using it"
+ " before using it",
+ // Velox's collect_list / collect_set are by design declarative aggregate so plan check
+ // for ObjectHashAggregateExec will fail.
+ "SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = true)",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = false)"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]
diff --git
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala
index cba70c21f..de56b8834 100644
---
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala
+++
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/GlutenDataFrameAggregateSuite.scala
@@ -188,21 +188,25 @@ class GlutenDataFrameAggregateSuite extends
DataFrameAggregateSuite with GlutenS
testGluten("use gluten hash agg to replace vanilla spark sort agg") {
withSQLConf(("spark.gluten.sql.columnar.force.hashagg", "false")) {
- Seq("A", "B", "C", "D").toDF("col1").createOrReplaceTempView("t1")
- // SortAggregateExec is expected to be used for string type input.
- val df = spark.sql("select max(col1) from t1")
- checkAnswer(df, Row("D") :: Nil)
- assert(find(df.queryExecution.executedPlan)(_.isInstanceOf[SortAggregateExec]).isDefined)
+ withTempView("t1") {
+ Seq("A", "B", "C", "D").toDF("col1").createOrReplaceTempView("t1")
+ // SortAggregateExec is expected to be used for string type input.
+ val df = spark.sql("select max(col1) from t1")
+ checkAnswer(df, Row("D") :: Nil)
+ assert(find(df.queryExecution.executedPlan)(_.isInstanceOf[SortAggregateExec]).isDefined)
+ }
}
withSQLConf(("spark.gluten.sql.columnar.force.hashagg", "true")) {
- Seq("A", "B", "C", "D").toDF("col1").createOrReplaceTempView("t1")
- val df = spark.sql("select max(col1) from t1")
- checkAnswer(df, Row("D") :: Nil)
- // Sort agg is expected to be replaced by gluten's hash agg.
- assert(
- find(df.queryExecution.executedPlan)(
- _.isInstanceOf[HashAggregateExecBaseTransformer]).isDefined)
+ withTempView("t1") {
+ Seq("A", "B", "C", "D").toDF("col1").createOrReplaceTempView("t1")
+ val df = spark.sql("select max(col1) from t1")
+ checkAnswer(df, Row("D") :: Nil)
+ // Sort agg is expected to be replaced by gluten's hash agg.
+ assert(
+ find(df.queryExecution.executedPlan)(
+ _.isInstanceOf[HashAggregateExecBaseTransformer]).isDefined)
+ }
}
}
@@ -279,4 +283,55 @@ class GlutenDataFrameAggregateSuite extends
DataFrameAggregateSuite with GlutenS
randn(Random.nextLong())
).foreach(assertNoExceptions)
}
+
+ Seq(true, false).foreach {
+ value =>
+ testGluten(s"SPARK-31620: agg with subquery (whole-stage-codegen = $value)") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> value.toString) {
+ withTempView("t1", "t2") {
+ sql("create temporary view t1 as select * from values (1, 2) as t1(a, b)")
+ sql("create temporary view t2 as select * from values (3, 4) as t2(c, d)")
+
+ // test without grouping keys
+ checkAnswer(
+ sql("select sum(if(c > (select a from t1), d, 0)) as csum from t2"),
+ Row(4) :: Nil)
+
+ // test with grouping keys
+ checkAnswer(
+ sql(
+ "select c, sum(if(c > (select a from t1), d, 0)) as csum from " +
+ "t2 group by c"),
+ Row(3, 4) :: Nil)
+
+ // test with distinct
+ checkAnswer(
+ sql(
+ "select avg(distinct(d)), sum(distinct(if(c > (select a from t1)," +
+ " d, 0))) as csum from t2 group by c"),
+ Row(4, 4) :: Nil)
+
+ // test subquery with agg
+ checkAnswer(
+ sql(
+ "select sum(distinct(if(c > (select sum(distinct(a)) from t1)," +
+ " d, 0))) as csum from t2 group by c"),
+ Row(4) :: Nil)
+
+ // test SortAggregateExec
+ var df = sql("select max(if(c > (select a from t1), 'str1', 'str2')) as csum from t2")
+ assert(
+ find(df.queryExecution.executedPlan)(_.isInstanceOf[SortAggregateExec]).isDefined)
+ checkAnswer(df, Row("str1") :: Nil)
+
+ // test SortAggregateExec (collect_list)
+ df =
+ sql("select collect_list(d), sum(if(c > (select a from t1), d, 0)) as csum from t2")
+ assert(
+ find(df.queryExecution.executedPlan)(_.isInstanceOf[SortAggregateExec]).isDefined)
+ checkAnswer(df, Row(Array(4), 4) :: Nil)
+ }
+ }
+ }
+ }
}
diff --git
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index c2385bd56..2aed0ff78 100644
---
a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++
b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -973,7 +973,12 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct
aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
- " before using it"
+ " before using it",
+ // Velox's collect_list / collect_set are by design declarative aggregate so plan check
+ // for ObjectHashAggregateExec will fail.
+ "SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = true)",
+ "SPARK-31620: agg with subquery (whole-stage-codegen = false)"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]