This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new ba980f615 Fix per comments (#6745)
ba980f615 is described below

commit ba980f61593d6993e24d839119284c78987fc99b
Author: Chang chen <[email protected]>
AuthorDate: Thu Aug 8 16:22:51 2024 +0800

    Fix per comments (#6745)
    
    [GLUTEN-6705] [CORE] [Part 1] Avoid adding c2r for ColumnarWriteFilesExec, 
since it outputs neither columnar batch data nor InternalRow
---
 .../gluten/backendsapi/clickhouse/package.scala    | 17 ++++++++++++++
 .../extension/columnar/ColumnarRuleApplier.scala   | 21 +++++++++++------
 .../columnar/transition/ConventionFunc.scala       | 22 +++++++++---------
 .../columnar/transition/Transitions.scala          | 26 ++++++++++------------
 .../sql/execution/ColumnarWriteFilesExec.scala     | 17 +++++++++++++-
 .../apache/gluten/metrics/GlutenTimeMetric.scala   |  6 +++++
 6 files changed, 75 insertions(+), 34 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/package.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/package.scala
index 8704fac7b..8975fb315 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/package.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/package.scala
@@ -21,6 +21,23 @@ import 
org.apache.gluten.extension.columnar.transition.Convention
 import org.apache.spark.sql.execution.{CHColumnarToRowExec, 
RowToCHNativeColumnarExec, SparkPlan}
 
 package object clickhouse {
+
+  /**
+   * ClickHouse batch convention.
+   *
+   * [[fromRow]] and [[toRow]] need a [[TransitionDef]] instance. Scala 
allows a compact way to
+   * implement a trait using a lambda function.
+   *
+   * Here the detailed definition is given in [[CHBatch.fromRow]].
+   * {{{
+   *       fromRow(new TransitionDef {
+   *       override def create(): Transition = new Transition {
+   *         override protected def apply0(plan: SparkPlan): SparkPlan =
+   *           RowToCHNativeColumnarExec(plan)
+   *       }
+   *     })
+   * }}}
+   */
   case object CHBatch extends Convention.BatchType {
     fromRow(
       () =>
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
index ee5bcd883..27213698b 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/ColumnarRuleApplier.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.utils.LogLevelUtil
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.execution.SparkPlan
 
 trait ColumnarRuleApplier {
@@ -47,13 +48,19 @@ object ColumnarRuleApplier {
     private val transformPlanLogLevel = 
GlutenConfig.getConf.transformPlanLogLevel
     override val ruleName: String = delegate.ruleName
 
-    override def apply(plan: SparkPlan): SparkPlan = 
GlutenTimeMetric.withMillisTime {
-      logOnLevel(
-        transformPlanLogLevel,
-        s"Preparing to apply rule $ruleName on plan:\n${plan.toString}")
-      val out = delegate.apply(plan)
-      logOnLevel(transformPlanLogLevel, s"Plan after applied rule 
$ruleName:\n${plan.toString}")
+    private def message(oldPlan: SparkPlan, newPlan: SparkPlan, millisTime: 
Long): String =
+      if (!oldPlan.fastEquals(newPlan)) {
+        s"""
+           |=== Applying Rule $ruleName took $millisTime ms ===
+           |${sideBySide(oldPlan.treeString, 
newPlan.treeString).mkString("\n")}
+           """.stripMargin
+      } else { s"Rule $ruleName has no effect, took $millisTime ms." }
+
+    override def apply(plan: SparkPlan): SparkPlan = {
+      val (out, millisTime) = 
GlutenTimeMetric.recordMillisTime(delegate.apply(plan))
+      logOnLevel(transformPlanLogLevel, message(plan, out, millisTime))
       out
-    }(t => logOnLevel(transformPlanLogLevel, s"Applying rule $ruleName took $t 
ms."))
+    }
+
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
index 453df5d88..beb809474 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
@@ -101,14 +101,14 @@ object ConventionFunc {
       val out = plan match {
         case k: Convention.KnownRowType =>
           k.rowType()
-        case _ if !SparkShimLoader.getSparkShims.supportsRowBased(plan) =>
-          Convention.RowType.None
-        case _ =>
+        case _ if SparkShimLoader.getSparkShims.supportsRowBased(plan) =>
           Convention.RowType.VanillaRow
+        case _ =>
+          Convention.RowType.None
       }
-      if (out != Convention.RowType.None) {
-        assert(SparkShimLoader.getSparkShims.supportsRowBased(plan))
-      }
+      assert(
+        out == Convention.RowType.None || 
plan.isInstanceOf[Convention.KnownRowType] ||
+          SparkShimLoader.getSparkShims.supportsRowBased(plan))
       out
     }
 
@@ -119,15 +119,13 @@ object ConventionFunc {
           p match {
             case k: Convention.KnownBatchType =>
               k.batchType()
-            case _ if !plan.supportsColumnar =>
-              Convention.BatchType.None
-            case _ =>
+            case _ if plan.supportsColumnar =>
               Convention.BatchType.VanillaBatch
+            case _ =>
+              Convention.BatchType.None
           }
       )
-      if (out != Convention.BatchType.None) {
-        assert(plan.supportsColumnar)
-      }
+      assert(out == Convention.BatchType.None || plan.supportsColumnar)
       out
     }
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index d02aadd49..602f0303c 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -77,22 +77,15 @@ object RemoveTransitions extends Rule[SparkPlan] {
 
 object Transitions {
   def insertTransitions(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan 
= {
-    val out = InsertTransitions(outputsColumnar).apply(plan)
-    out
+    InsertTransitions(outputsColumnar).apply(plan)
   }
 
   def toRowPlan(plan: SparkPlan): SparkPlan = {
-    val convFunc = ConventionFunc.create()
-    val req = ConventionReq.of(
-      ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
-      ConventionReq.BatchType.Any)
-    val removed = RemoveTransitions.removeForNode(plan)
-    val transition = Transition.factory.findTransition(
-      convFunc.conventionOf(removed),
-      req,
-      Transition.notFound(removed, req))
-    val out = transition.apply(removed)
-    out
+    enforceReq(
+      plan,
+      ConventionReq.of(
+        ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
+        ConventionReq.BatchType.Any))
   }
 
   def toBackendBatchPlan(plan: SparkPlan): SparkPlan = {
@@ -107,8 +100,13 @@ object Transitions {
   }
 
   private def toBatchPlan(plan: SparkPlan, toBatchType: Convention.BatchType): 
SparkPlan = {
+    enforceReq(
+      plan,
+      ConventionReq.of(ConventionReq.RowType.Any, 
ConventionReq.BatchType.Is(toBatchType)))
+  }
+
+  private def enforceReq(plan: SparkPlan, req: ConventionReq): SparkPlan = {
     val convFunc = ConventionFunc.create()
-    val req = ConventionReq.of(ConventionReq.RowType.Any, 
ConventionReq.BatchType.Is(toBatchType))
     val removed = RemoveTransitions.removeForNode(plan)
     val transition = Transition.factory.findTransition(
       convFunc.conventionOf(removed),
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarWriteFilesExec.scala
 
b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarWriteFilesExec.scala
index 6f04b8480..b809a4b65 100644
--- 
a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarWriteFilesExec.scala
+++ 
b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarWriteFilesExec.scala
@@ -19,6 +19,9 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.GlutenPlan
+import 
org.apache.gluten.extension.columnar.transition.Convention.{KnownRowType, 
RowType}
+import org.apache.gluten.extension.columnar.transition.ConventionReq
+import 
org.apache.gluten.extension.columnar.transition.ConventionReq.KnownChildrenConventions
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.{Partition, SparkException, TaskContext, 
TaskOutputFileAlreadyExistException}
@@ -269,12 +272,24 @@ object ColumnarWriteFilesExec {
     override def output: Seq[Attribute] = Seq.empty
   }
 
-  sealed trait ExecuteWriteCompatible {
+  /**
+   * ColumnarWriteFilesExec outputs neither rows nor columnar data. We claim 
to output both row and columnar to
+   * avoid c2r and r2c transitions. Please note, [[GlutenPlan]] already 
implements batchType().
+   */
+  sealed trait ExecuteWriteCompatible extends KnownChildrenConventions with 
KnownRowType {
     // To be compatible with Spark (version < 3.4)
     protected def doExecuteWrite(writeFilesSpec: WriteFilesSpec): 
RDD[WriterCommitMessage] = {
       throw new GlutenException(
         s"Internal Error ${this.getClass} has write support" +
           s" mismatch:\n${this}")
     }
+
+    override def requiredChildrenConventions(): Seq[ConventionReq] = {
+      List(ConventionReq.backendBatch)
+    }
+
+    override def rowType(): RowType = {
+      RowType.VanillaRow
+    }
   }
 }
diff --git 
a/shims/common/src/main/scala/org/apache/gluten/metrics/GlutenTimeMetric.scala 
b/shims/common/src/main/scala/org/apache/gluten/metrics/GlutenTimeMetric.scala
index 37e824ea6..a2a187a4d 100644
--- 
a/shims/common/src/main/scala/org/apache/gluten/metrics/GlutenTimeMetric.scala
+++ 
b/shims/common/src/main/scala/org/apache/gluten/metrics/GlutenTimeMetric.scala
@@ -44,4 +44,10 @@ object GlutenTimeMetric {
   }
   def withMillisTime[U](block: => U)(millisTime: Long => Unit): U =
     withNanoTime(block)(t => millisTime(TimeUnit.NANOSECONDS.toMillis(t)))
+
+  def recordMillisTime[U](block: => U): (U, Long) = {
+    var time = 0L
+    val result = withMillisTime(block)(time = _)
+    (result, time)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to