This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new c0bd7f25e7 [CORE] Query planner: A more explicit practice to register columnar batch types (#8002)
c0bd7f25e7 is described below

commit c0bd7f25e7dacef2b4d76db1dd52078e58db2dc3
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Nov 20 13:31:00 2024 +0800

    [CORE] Query planner: A more explicit practice to register columnar batch types (#8002)
---
 .../backendsapi/clickhouse/CHListenerApi.scala     |  2 +-
 .../org/apache/gluten/columnarbatch/CHBatch.scala  |  6 ++-
 .../backendsapi/velox/VeloxListenerApi.scala       |  6 +--
 .../apache/gluten/columnarbatch/VeloxBatch.scala   | 10 +++--
 .../apache/gluten/columnarbatch/ArrowBatches.scala | 10 +++--
 .../enumerated/planner/property/Conv.scala         |  4 +-
 .../extension/columnar/transition/Convention.scala | 43 ++++++++++++++++++----
 .../columnar/transition/Transitions.scala          |  9 ++---
 .../columnar/EnsureLocalSortRequirements.scala     |  2 +-
 .../columnar/transition/TransitionSuite.scala      | 26 ++++++++-----
 10 files changed, 80 insertions(+), 38 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index 6ae957912a..b93c002561 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -71,7 +71,7 @@ class CHListenerApi extends ListenerApi with Logging {
 
   private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
     // Force batch type initializations.
-    CHBatch.getClass
+    CHBatch.ensureRegistered()
     SparkDirectoryUtil.init(conf)
     val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY)
     if (StringUtils.isBlank(libPath)) {
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
index 870a731b11..ac0ca5f8b4 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
@@ -38,6 +38,8 @@ import org.apache.spark.sql.execution.{CHColumnarToRowExec, RowToCHNativeColumna
  * }}}
  */
 object CHBatch extends Convention.BatchType {
-  fromRow(RowToCHNativeColumnarExec.apply)
-  toRow(CHColumnarToRowExec.apply)
+  override protected def registerTransitions(): Unit = {
+    fromRow(RowToCHNativeColumnarExec.apply)
+    toRow(CHColumnarToRowExec.apply)
+  }
 }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index a33fd22812..85de0a8889 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -123,9 +123,9 @@ class VeloxListenerApi extends ListenerApi with Logging {
 
   private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
     // Force batch type initializations.
-    VeloxBatch.getClass
-    ArrowJavaBatch.getClass
-    ArrowNativeBatch.getClass
+    VeloxBatch.ensureRegistered()
+    ArrowJavaBatch.ensureRegistered()
+    ArrowNativeBatch.ensureRegistered()
 
     // Sets this configuration only once, since not undoable.
     if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, 
defaultValue = false)) {
diff --git a/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala b/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
index 0c7600c856..5d9a78d318 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
@@ -20,8 +20,10 @@ import org.apache.gluten.execution.{ArrowColumnarToVeloxColumnarExec, RowToVelox
 import org.apache.gluten.extension.columnar.transition.{Convention, Transition}
 
 object VeloxBatch extends Convention.BatchType {
-  fromRow(RowToVeloxColumnarExec.apply)
-  toRow(VeloxColumnarToRowExec.apply)
-  fromBatch(ArrowBatches.ArrowNativeBatch, 
ArrowColumnarToVeloxColumnarExec.apply)
-  toBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
+  override protected def registerTransitions(): Unit = {
+    fromRow(RowToVeloxColumnarExec.apply)
+    toRow(VeloxColumnarToRowExec.apply)
+    fromBatch(ArrowBatches.ArrowNativeBatch, 
ArrowColumnarToVeloxColumnarExec.apply)
+    toBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
+  }
 }
diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala b/gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala
index 5ae3863c57..c23d6eea79 100644
--- a/gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala
+++ b/gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala
@@ -33,7 +33,9 @@ object ArrowBatches {
    * implementations.
    */
   object ArrowJavaBatch extends Convention.BatchType {
-    toBatch(VanillaBatch, Transition.empty)
+    override protected def registerTransitions(): Unit = {
+      toBatch(VanillaBatch, Transition.empty)
+    }
   }
 
   /**
@@ -44,7 +46,9 @@ object ArrowBatches {
    * [[ColumnarBatches]].
    */
   object ArrowNativeBatch extends Convention.BatchType {
-    fromBatch(ArrowJavaBatch, OffloadArrowDataExec.apply)
-    toBatch(ArrowJavaBatch, LoadArrowDataExec.apply)
+    override protected def registerTransitions(): Unit = {
+      fromBatch(ArrowJavaBatch, OffloadArrowDataExec.apply)
+      toBatch(ArrowJavaBatch, LoadArrowDataExec.apply)
+    }
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
index e9ca836eee..831b212e1f 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
@@ -34,7 +34,7 @@ sealed trait Conv extends Property[SparkPlan] {
       return true
     }
     val prop = this.asInstanceOf[Prop]
-    val out = Transition.factory.satisfies(prop.prop, req.req)
+    val out = Transition.factory().satisfies(prop.prop, req.req)
     out
   }
 }
@@ -52,7 +52,7 @@ object Conv {
   def findTransition(from: Conv, to: Conv): Transition = {
     val prop = from.asInstanceOf[Prop]
     val req = to.asInstanceOf[Req]
-    val out = Transition.factory.findTransition(prop.prop, req.req, new 
IllegalStateException())
+    val out = Transition.factory().findTransition(prop.prop, req.req, new 
IllegalStateException())
     out
   }
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index 55bcb84d2b..840b62fb67 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -18,6 +18,8 @@ package org.apache.gluten.extension.columnar.transition
 
 import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, 
SparkPlan}
 
+import java.util.concurrent.atomic.AtomicBoolean
+
 /**
  * Convention of a query plan consists of the row data type and columnar data 
type it supports to
  * output.
@@ -72,22 +74,43 @@ object Convention {
   }
 
   trait BatchType extends TransitionGraph.Vertex with Serializable {
-    Transition.graph.addVertex(this)
+    private val initialized: AtomicBoolean = new AtomicBoolean(false)
+
+    final def ensureRegistered(): Unit = {
+      if (!initialized.compareAndSet(false, true)) {
+        // Already registered.
+        return
+      }
+      register()
+    }
 
-    final protected def fromRow(transition: Transition): Unit = {
+    final private def register(): Unit = {
+      Transition.graph.addVertex(this)
+      registerTransitions()
+    }
+
+    ensureRegistered()
+
+    /**
+     * User batch type could override this method to define transitions 
from/to this batch type by
+     * calling the subsequent protected APIs.
+     */
+    protected[this] def registerTransitions(): Unit
+
+    final protected[this] def fromRow(transition: Transition): Unit = {
       Transition.graph.addEdge(RowType.VanillaRow, this, transition)
     }
 
-    final protected def toRow(transition: Transition): Unit = {
+    final protected[this] def toRow(transition: Transition): Unit = {
       Transition.graph.addEdge(this, RowType.VanillaRow, transition)
     }
 
-    final protected def fromBatch(from: BatchType, transition: Transition): 
Unit = {
+    final protected[this] def fromBatch(from: BatchType, transition: 
Transition): Unit = {
       assert(from != this)
       Transition.graph.addEdge(from, this, transition)
     }
 
-    final protected def toBatch(to: BatchType, transition: Transition): Unit = 
{
+    final protected[this] def toBatch(to: BatchType, transition: Transition): 
Unit = {
       assert(to != this)
       Transition.graph.addEdge(this, to, transition)
     }
@@ -95,10 +118,14 @@ object Convention {
 
   object BatchType {
     // None indicates that the plan doesn't support batch-based processing.
-    final case object None extends BatchType
+    final case object None extends BatchType {
+      override protected[this] def registerTransitions(): Unit = {}
+    }
     final case object VanillaBatch extends BatchType {
-      fromRow(RowToColumnarExec.apply)
-      toRow(ColumnarToRowExec.apply)
+      override protected[this] def registerTransitions(): Unit = {
+        fromRow(RowToColumnarExec.apply)
+        toRow(ColumnarToRowExec.apply)
+      }
     }
   }
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index 9987a65b0c..2f2840b52b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -54,7 +54,7 @@ case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
           child
         } else {
           val transition =
-            Transition.factory.findTransition(from, convReq, 
Transition.notFound(node))
+            Transition.factory().findTransition(from, convReq, 
Transition.notFound(node))
           val newChild = transition.apply(child)
           newChild
         }
@@ -108,10 +108,9 @@ object Transitions {
   private def enforceReq(plan: SparkPlan, req: ConventionReq): SparkPlan = {
     val convFunc = ConventionFunc.create()
     val removed = RemoveTransitions.removeForNode(plan)
-    val transition = Transition.factory.findTransition(
-      convFunc.conventionOf(removed),
-      req,
-      Transition.notFound(removed, req))
+    val transition = Transition
+      .factory()
+      .findTransition(convFunc.conventionOf(removed), req, 
Transition.notFound(removed, req))
     val out = transition.apply(removed)
     out
   }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
index 29a7652885..056315186d 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/EnsureLocalSortRequirements.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.{SortExec, SparkPlan}
  * SortAggregate with the same key. So, this rule adds local sort back if 
necessary.
  */
 object EnsureLocalSortRequirements extends Rule[SparkPlan] {
-  private lazy val transform = HeuristicTransform.static()
+  private lazy val transform: HeuristicTransform = HeuristicTransform.static()
 
   private def addLocalSort(
       originalChild: SparkPlan,
diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
index 5c35cb5020..9712bd2c21 100644
--- a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
+++ b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
@@ -87,23 +87,31 @@ class TransitionSuite extends SharedSparkSession {
 
 object TransitionSuite extends TransitionSuiteBase {
   object TypeA extends Convention.BatchType {
-    fromRow(RowToBatch(this, _))
-    toRow(BatchToRow(this, _))
+    override protected[this] def registerTransitions(): Unit = {
+      fromRow(RowToBatch(this, _))
+      toRow(BatchToRow(this, _))
+    }
   }
 
   object TypeB extends Convention.BatchType {
-    fromRow(RowToBatch(this, _))
-    toRow(BatchToRow(this, _))
+    override protected[this] def registerTransitions(): Unit = {
+      fromRow(RowToBatch(this, _))
+      toRow(BatchToRow(this, _))
+    }
   }
 
   object TypeC extends Convention.BatchType {
-    fromRow(RowToBatch(this, _))
-    toRow(BatchToRow(this, _))
-    fromBatch(TypeA, BatchToBatch(TypeA, this, _))
-    toBatch(TypeA, BatchToBatch(this, TypeA, _))
+    override protected[this] def registerTransitions(): Unit = {
+      fromRow(RowToBatch(this, _))
+      toRow(BatchToRow(this, _))
+      fromBatch(TypeA, BatchToBatch(TypeA, this, _))
+      toBatch(TypeA, BatchToBatch(this, TypeA, _))
+    }
   }
 
-  object TypeD extends Convention.BatchType {}
+  object TypeD extends Convention.BatchType {
+    override protected[this] def registerTransitions(): Unit = {}
+  }
 
   case class RowToBatch(toBatchType: Convention.BatchType, override val child: 
SparkPlan)
     extends RowToColumnarTransition


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to