This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 0702eb8f7 [GLUTEN-5016][CH] fix simple aggregation sql exchange
fallback (#5042)
0702eb8f7 is described below
commit 0702eb8f73051407eb0dd9da495000360f459edd
Author: Wenzheng Liu <[email protected]>
AuthorDate: Fri Mar 22 13:03:43 2024 +0800
[GLUTEN-5016][CH] fix simple aggregation sql exchange fallback (#5042)
---
.../GlutenClickHouseTPCHNullableSuite.scala | 15 ++
.../extension/columnar/MiscColumnarRules.scala | 263 ++++++++++++---------
2 files changed, 163 insertions(+), 115 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
index 72a04f02c..42783babd 100644
---
a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
+++
b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala
@@ -195,4 +195,19 @@ class GlutenClickHouseTPCHNullableSuite extends
GlutenClickHouseTPCHAbstractSuit
|""".stripMargin) { _ => }
assert(result(0).getLong(0) == 227302L)
}
+
+ test("test 'GLUTEN-5016'") {
+ withSQLConf(("spark.gluten.sql.columnar.preferColumnar", "false")) {
+ val sql =
+ """
+ |SELECT
+ | sum(l_quantity) AS sum_qty
+ |FROM
+ | lineitem
+ |WHERE
+ | l_shipdate <= date'1998-09-02'
+ |""".stripMargin
+ runSql(sql, noFallBack = true) { _ => }
+ }
+ }
}
diff --git
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
index 4fbadb0b5..d21956bfc 100644
---
a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
+++
b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala
@@ -49,9 +49,12 @@ object MiscColumnarRules {
// Aggregation transformation.
private case class AggregationTransformRule() extends Rule[SparkPlan] with
LogLevelUtil {
- override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ override def apply(plan: SparkPlan): SparkPlan = plan match {
+ case plan if TransformHints.isNotTransformable(plan) =>
+ plan
case agg: HashAggregateExec =>
genHashAggregateExec(agg)
+ case other => other
}
/**
@@ -105,13 +108,144 @@ object MiscColumnarRules {
}
}
+ // Exchange transformation.
+ private case class ExchangeTransformRule() extends Rule[SparkPlan] with
LogLevelUtil {
+ override def apply(plan: SparkPlan): SparkPlan = plan match {
+ case plan if TransformHints.isNotTransformable(plan) =>
+ plan
+ case plan: ShuffleExchangeExec =>
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
+ val child = plan.child
+ if (
+ (child.supportsColumnar ||
GlutenConfig.getConf.enablePreferColumnar) &&
+ BackendsApiManager.getSettings.supportColumnarShuffleExec()
+ ) {
+
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan,
child)
+ } else {
+ plan.withNewChildren(Seq(child))
+ }
+ case plan: BroadcastExchangeExec =>
+ val child = plan.child
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
+ ColumnarBroadcastExchangeExec(plan.mode, child)
+ case other => other
+ }
+ }
+
+ // Join transformation.
+ private case class JoinTransformRule() extends Rule[SparkPlan] with
LogLevelUtil {
+
+ /**
+ * Get the build side supported by the execution of vanilla Spark.
+ *
+ * @param plan
+ * : shuffled hash join plan
+ * @return
+ * the supported build side
+ */
+ private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec):
BuildSide = {
+ plan.joinType match {
+ case LeftOuter | LeftSemi => BuildRight
+ case RightOuter => BuildLeft
+ case _ => plan.buildSide
+ }
+ }
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ if (TransformHints.isNotTransformable(plan)) {
+ logDebug(s"Columnar Processing for ${plan.getClass} is under row
guard.")
+ plan match {
+ case shj: ShuffledHashJoinExec =>
+ if (BackendsApiManager.getSettings.recreateJoinExecOnFallback())
{
+ // Because we manually removed the build side limitation for
LeftOuter, LeftSemi and
+ // RightOuter, need to change the build side back if this join
fallback into vanilla
+ // Spark for execution.
+ return ShuffledHashJoinExec(
+ shj.leftKeys,
+ shj.rightKeys,
+ shj.joinType,
+ getSparkSupportedBuildSide(shj),
+ shj.condition,
+ shj.left,
+ shj.right,
+ shj.isSkewJoin
+ )
+ } else {
+ return shj
+ }
+ case p =>
+ return p
+ }
+ }
+ plan match {
+ case plan: ShuffledHashJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genShuffledHashJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.buildSide,
+ plan.condition,
+ left,
+ right,
+ plan.isSkewJoin)
+ case plan: SortMergeJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
+ SortMergeJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.condition,
+ left,
+ right,
+ plan.isSkewJoin)
+ case plan: BroadcastHashJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genBroadcastHashJoinExecTransformer(
+ plan.leftKeys,
+ plan.rightKeys,
+ plan.joinType,
+ plan.buildSide,
+ plan.condition,
+ left,
+ right,
+ isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
+ case plan: CartesianProductExec =>
+ val left = plan.left
+ val right = plan.right
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genCartesianProductExecTransformer(left, right, plan.condition)
+ case plan: BroadcastNestedLoopJoinExec =>
+ val left = plan.left
+ val right = plan.right
+ BackendsApiManager.getSparkPlanExecApiInstance
+ .genBroadcastNestedLoopJoinExecTransformer(
+ left,
+ right,
+ plan.buildSide,
+ plan.joinType,
+ plan.condition)
+ case other => other
+ }
+ }
+
+ }
+
// Filter transformation.
private case class FilterTransformRule() extends Rule[SparkPlan] with
LogLevelUtil {
private val replace = new ReplaceSingleNode()
- override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
+ override def apply(plan: SparkPlan): SparkPlan = plan match {
case filter: FilterExec =>
genFilterExec(filter)
+ case other => other
}
/**
@@ -155,39 +289,18 @@ object MiscColumnarRules {
private case class RegularTransformRule() extends Rule[SparkPlan] with
LogLevelUtil {
private val replace = new ReplaceSingleNode()
- override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case plan => replace.replaceWithTransformerPlan(plan)
- }
+ override def apply(plan: SparkPlan): SparkPlan =
replace.replaceWithTransformerPlan(plan)
}
// Utility to replace single node within transformed Gluten node.
// Children will be preserved as they are as children of the output node.
class ReplaceSingleNode() extends LogLevelUtil with Logging {
- private val columnarConf: GlutenConfig = GlutenConfig.getConf
def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
val plan = p
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row
guard.")
plan match {
- case shj: ShuffledHashJoinExec =>
- if (BackendsApiManager.getSettings.recreateJoinExecOnFallback())
{
- // Because we manually removed the build side limitation for
LeftOuter, LeftSemi and
- // RightOuter, need to change the build side back if this join
fallback into vanilla
- // Spark for execution.
- return ShuffledHashJoinExec(
- shj.leftKeys,
- shj.rightKeys,
- shj.joinType,
- getSparkSupportedBuildSide(shj),
- shj.condition,
- shj.left,
- shj.right,
- shj.isSkewJoin
- )
- } else {
- return shj
- }
case plan: BatchScanExec =>
return applyScanNotTransformable(plan)
case plan: FileSourceScanExec =>
@@ -283,75 +396,6 @@ object MiscColumnarRules {
plan.projectList,
child,
offset)
- case plan: ShuffleExchangeExec =>
- logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- val child = plan.child
- if (
- (child.supportsColumnar || columnarConf.enablePreferColumnar) &&
- BackendsApiManager.getSettings.supportColumnarShuffleExec()
- ) {
-
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan,
child)
- } else {
- plan.withNewChildren(Seq(child))
- }
- case plan: ShuffledHashJoinExec =>
- val left = plan.left
- val right = plan.right
- logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- BackendsApiManager.getSparkPlanExecApiInstance
- .genShuffledHashJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.buildSide,
- plan.condition,
- left,
- right,
- plan.isSkewJoin)
- case plan: SortMergeJoinExec =>
- val left = plan.left
- val right = plan.right
- logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- SortMergeJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.condition,
- left,
- right,
- plan.isSkewJoin)
- case plan: BroadcastExchangeExec =>
- val child = plan.child
- logDebug(s"Columnar Processing for ${plan.getClass} is currently
supported.")
- ColumnarBroadcastExchangeExec(plan.mode, child)
- case plan: BroadcastHashJoinExec =>
- val left = plan.left
- val right = plan.right
- BackendsApiManager.getSparkPlanExecApiInstance
- .genBroadcastHashJoinExecTransformer(
- plan.leftKeys,
- plan.rightKeys,
- plan.joinType,
- plan.buildSide,
- plan.condition,
- left,
- right,
- isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
- case plan: CartesianProductExec =>
- val left = plan.left
- val right = plan.right
- BackendsApiManager.getSparkPlanExecApiInstance
- .genCartesianProductExecTransformer(left, right, plan.condition)
- case plan: BroadcastNestedLoopJoinExec =>
- val left = plan.left
- val right = plan.right
- BackendsApiManager.getSparkPlanExecApiInstance
- .genBroadcastNestedLoopJoinExecTransformer(
- left,
- right,
- plan.buildSide,
- plan.joinType,
- plan.condition)
case plan: WindowExec =>
WindowExecTransformer(
plan.windowExpression,
@@ -389,22 +433,6 @@ object MiscColumnarRules {
}
}
- /**
- * Get the build side supported by the execution of vanilla Spark.
- *
- * @param plan
- * : shuffled hash join plan
- * @return
- * the supported build side
- */
- private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec):
BuildSide = {
- plan.joinType match {
- case LeftOuter | LeftSemi => BuildRight
- case RightOuter => BuildLeft
- case _ => plan.buildSide
- }
- }
-
private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan
match {
case plan: FileSourceScanExec =>
val newPartitionFilters =
@@ -489,18 +517,23 @@ object MiscColumnarRules {
case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil
{
import TransformPreOverrides._
- private val subRules = List(
- FilterTransformRule(),
+ private val topdownRules = List(
+ FilterTransformRule()
+ )
+ private val bottomupRules = List(
RegularTransformRule(),
- AggregationTransformRule()
+ AggregationTransformRule(),
+ ExchangeTransformRule(),
+ JoinTransformRule()
)
@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
def apply(plan: SparkPlan): SparkPlan = {
- val newPlan = subRules.foldLeft(plan)((p, rule) => rule.apply(p))
- planChangeLogger.logRule(ruleName, plan, newPlan)
- newPlan
+ val plan0 = topdownRules.foldLeft(plan)((p, rule) => p.transformDown {
case p => rule(p) })
+ val plan1 = bottomupRules.foldLeft(plan0)((p, rule) => p.transformUp {
case p => rule(p) })
+ planChangeLogger.logRule(ruleName, plan, plan1)
+ plan1
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]