This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 255b0cc8eb [GLUTEN-7313][VL] Explicit Arrow transitions, part 2: new 
algorithm to find optimal transition (#7372)
255b0cc8eb is described below

commit 255b0cc8ebae7b23e2966dfd922b591958b8b6fa
Author: Hongze Zhang <[email protected]>
AuthorDate: Sun Sep 29 09:10:55 2024 +0800

    [GLUTEN-7313][VL] Explicit Arrow transitions, part 2: new algorithm to find 
optimal transition (#7372)
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |   2 +-
 .../backendsapi/clickhouse/CHListenerApi.scala     |   3 +
 .../org/apache/gluten/columnarbatch/CHBatch.scala  |   5 +-
 .../gluten/backendsapi/velox/VeloxBackend.scala    |  24 ++--
 .../backendsapi/velox/VeloxListenerApi.scala       |   8 ++
 .../api/python/ColumnarArrowEvalPythonExec.scala   |  21 ++-
 .../columnar/transition/VeloxTransitionSuite.scala |  44 +++---
 .../scala/org/apache/gluten/backend/Backend.scala  |   9 +-
 .../extension/columnar/transition/Convention.scala |  24 ++--
 .../columnar/transition/ConventionFunc.scala       |  56 +++++---
 .../columnar/transition/ConventionReq.scala        |   2 +-
 .../columnar/transition/FloydWarshallGraph.scala   | 158 +++++++++++++++++++++
 .../extension/columnar/transition/Transition.scala | 119 ++++------------
 .../columnar/transition/TransitionGraph.scala      |  90 ++++++++++++
 .../columnar/transition/Transitions.scala          |   2 +-
 .../extension/columnar/transition/package.scala    |   1 +
 .../apache/spark/util/SparkReflectionUtil.scala    |  23 +++
 .../transition/FloydWarshallGraphSuite.scala       | 103 ++++++++++++++
 .../org/apache/gluten/extension/GlutenPlan.scala   |   2 +-
 .../scala/org/apache/gluten/utils/PlanUtil.scala   |   4 +-
 .../ColumnarCollapseTransformStages.scala          |   2 +-
 21 files changed, 521 insertions(+), 181 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index d8966519da..900c9bafb4 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -49,7 +49,7 @@ import scala.util.control.Breaks.{break, breakable}
 
 class CHBackend extends SubstraitBackend {
   override def name(): String = CHConf.BACKEND_NAME
-  override def batchType: Convention.BatchType = CHBatch
+  override def defaultBatchType: Convention.BatchType = CHBatch
   override def buildInfo(): Backend.BuildInfo =
     Backend.BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
   override def iteratorApi(): IteratorApi = new CHIteratorApi
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index f326f2edcf..16f5fa064c 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.backendsapi.clickhouse
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.ListenerApi
+import org.apache.gluten.columnarbatch.CHBatch
 import org.apache.gluten.execution.CHBroadcastBuildSideCache
 import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, 
GlutenParquetWriterInjects, GlutenRowSplitter}
 import org.apache.gluten.expression.UDFMappings
@@ -68,6 +69,8 @@ class CHListenerApi extends ListenerApi with Logging {
   override def onExecutorShutdown(): Unit = shutdown()
 
   private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
+    // Force batch type initializations.
+    CHBatch.getClass
     SparkDirectoryUtil.init(conf)
     val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY)
     if (StringUtils.isBlank(libPath)) {
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
index 0121d01578..079fa6bfd7 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.execution.{CHColumnarToRowExec, 
RowToCHNativeColumna
 /**
  * ClickHouse batch convention.
  *
- * [[fromRow]] and [[toRow]] need a [[TransitionDef]] instance. The scala 
allows an compact way to
- * implement trait using a lambda function.
+ * [[fromRow]] and [[toRow]] need a
+ * [[org.apache.gluten.extension.columnar.transition.TransitionDef]] instance. 
The scala allows an
+ * compact way to implement trait using a lambda function.
  *
  * Here the detail definition is given in [[CHBatch.fromRow]].
  * {{{
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 027876a2b0..56dc92f420 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -25,8 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
 import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.extension.columnar.transition.Convention
-import 
org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
+import org.apache.gluten.extension.columnar.transition.{Convention, 
ConventionFunc}
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 import 
org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, 
OrcReadFormat, ParquetReadFormat}
@@ -37,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, 
CumeDist, DenseRank, De
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
ApproximatePercentile}
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
+import org.apache.spark.sql.execution.{ColumnarCachedBatchSerializer, 
SparkPlan}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import 
org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
 import org.apache.spark.sql.execution.datasources.{FileFormat, 
InsertIntoHadoopFsRelationCommand}
@@ -51,14 +50,10 @@ import org.apache.hadoop.fs.Path
 import scala.util.control.Breaks.breakable
 
 class VeloxBackend extends SubstraitBackend {
+  import VeloxBackend._
   override def name(): String = VeloxBackend.BACKEND_NAME
-  override def batchType: Convention.BatchType = VeloxBatch
-  override def batchTypeFunc(): BatchOverride = {
-    case i: InMemoryTableScanExec
-        if i.supportsColumnar && i.relation.cacheBuilder.serializer
-          .isInstanceOf[ColumnarCachedBatchSerializer] =>
-      VeloxBatch
-  }
+  override def defaultBatchType: Convention.BatchType = VeloxBatch
+  override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
   override def buildInfo(): Backend.BuildInfo =
     Backend.BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, 
VELOX_REVISION_TIME)
   override def iteratorApi(): IteratorApi = new VeloxIteratorApi
@@ -74,6 +69,15 @@ class VeloxBackend extends SubstraitBackend {
 object VeloxBackend {
   val BACKEND_NAME: String = "velox"
   val CONF_PREFIX: String = GlutenConfig.prefixOf(BACKEND_NAME)
+
+  private class ConvFunc() extends ConventionFunc.Override {
+    override def batchTypeOf: PartialFunction[SparkPlan, Convention.BatchType] 
= {
+      case i: InMemoryTableScanExec
+          if i.supportsColumnar && i.relation.cacheBuilder.serializer
+            .isInstanceOf[ColumnarCachedBatchSerializer] =>
+        VeloxBatch
+    }
+  }
 }
 
 object VeloxBackendSettings extends BackendSettingsApi {
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index 926991ada7..e763e31dc5 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -18,6 +18,8 @@ package org.apache.gluten.backendsapi.velox
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.ListenerApi
+import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch, 
ArrowNativeBatch}
+import org.apache.gluten.columnarbatch.VeloxBatch
 import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, 
GlutenParquetWriterInjects, GlutenRowSplitter}
 import org.apache.gluten.expression.UDFMappings
 import org.apache.gluten.init.NativeBackendInitializer
@@ -119,6 +121,12 @@ class VeloxListenerApi extends ListenerApi with Logging {
   override def onExecutorShutdown(): Unit = shutdown()
 
   private def initialize(conf: SparkConf): Unit = {
+    // Force batch type initializations.
+    VeloxBatch.getClass
+    ArrowJavaBatch.getClass
+    ArrowNativeBatch.getClass
+
+    // Sets this configuration only once, since not undoable.
     if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, 
defaultValue = false)) {
       val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR)
       JniWorkspace.enableDebug(debugDir)
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
 
b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
index 5c27f94ca8..f1f5eb9062 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
@@ -16,8 +16,8 @@
  */
 package org.apache.spark.api.python
 
-import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxBatch}
 import org.apache.gluten.columnarbatch.ArrowBatches.ArrowJavaBatch
+import org.apache.gluten.columnarbatch.ColumnarBatches
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.extension.columnar.transition.{Convention, 
ConventionReq}
@@ -218,13 +218,8 @@ case class ColumnarArrowEvalPythonExec(
 
   override protected def batchType0(): Convention.BatchType = ArrowJavaBatch
 
-  // FIXME: Make this accepts ArrowJavaBatch as input. Before doing that, a 
weight-based
-  //  shortest patch algorithm should be added into transition factory. So 
that the factory
-  //  can find out row->velox->arrow-native->arrow-java as the possible viable 
transition.
-  //  Otherwise with current solution, any input (even already in Arrow Java 
format) will be
-  //  converted into Velox format then into Arrow Java format before entering 
python runner.
   override def requiredChildrenConventions(): Seq[ConventionReq] = List(
-    ConventionReq.of(ConventionReq.RowType.Any, 
ConventionReq.BatchType.Is(VeloxBatch)))
+    ConventionReq.of(ConventionReq.RowType.Any, 
ConventionReq.BatchType.Is(ArrowJavaBatch)))
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
@@ -348,17 +343,17 @@ case class ColumnarArrowEvalPythonExec(
         val inputBatchIter = contextAwareIterator.map {
           inputCb =>
             start_time = System.nanoTime()
-            val loaded = 
ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), inputCb)
-            ColumnarBatches.retain(loaded)
+            ColumnarBatches.checkLoaded(inputCb)
+            ColumnarBatches.retain(inputCb)
             // 0. cache input for later merge
-            inputCbCache += loaded
-            numInputRows += loaded.numRows
+            inputCbCache += inputCb
+            numInputRows += inputCb.numRows
             // We only need to pass the referred cols data to python worker 
for evaluation.
             var colsForEval = new ArrayBuffer[ColumnVector]()
             for (i <- originalOffsets) {
-              colsForEval += loaded.column(i)
+              colsForEval += inputCb.column(i)
             }
-            new ColumnarBatch(colsForEval.toArray, loaded.numRows())
+            new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
         }
 
         val outputColumnarBatchIterator =
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
index df8ee8cc53..e23e7885f2 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
@@ -19,7 +19,6 @@ package org.apache.gluten.extension.columnar.transition
 import org.apache.gluten.backendsapi.velox.VeloxListenerApi
 import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch, 
ArrowNativeBatch}
 import org.apache.gluten.columnarbatch.VeloxBatch
-import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.execution.{LoadArrowDataExec, OffloadArrowDataExec, 
RowToVeloxColumnarExec, VeloxColumnarToRowExec}
 import 
org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch
 import org.apache.gluten.test.MockVeloxBackend
@@ -64,11 +63,10 @@ class VeloxTransitionSuite extends SharedSparkSession {
 
   test("ArrowNative R2C - requires Arrow input") {
     val in = BatchUnary(ArrowNativeBatch, RowLeaf())
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(ArrowNativeBatch, 
RowToVeloxColumnarExec(RowLeaf())))))
   }
 
   test("ArrowNative-to-Velox C2C") {
@@ -92,11 +90,12 @@ class VeloxTransitionSuite extends SharedSparkSession {
 
   test("Vanilla-to-ArrowNative C2C") {
     val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VanillaBatch))
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(
+          ArrowNativeBatch,
+          
RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
   }
 
   test("ArrowNative-to-Vanilla C2C") {
@@ -121,11 +120,10 @@ class VeloxTransitionSuite extends SharedSparkSession {
 
   test("ArrowJava R2C - requires Arrow input") {
     val in = BatchUnary(ArrowJavaBatch, RowLeaf())
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        BatchUnary(ArrowJavaBatch, 
LoadArrowDataExec(RowToVeloxColumnarExec(RowLeaf())))))
   }
 
   test("ArrowJava-to-Velox C2C") {
@@ -146,11 +144,12 @@ class VeloxTransitionSuite extends SharedSparkSession {
 
   test("Vanilla-to-ArrowJava C2C") {
     val in = BatchUnary(ArrowJavaBatch, BatchLeaf(VanillaBatch))
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        BatchUnary(
+          ArrowJavaBatch,
+          
LoadArrowDataExec(RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
   }
 
   test("ArrowJava-to-Vanilla C2C") {
@@ -195,8 +194,7 @@ class VeloxTransitionSuite extends SharedSparkSession {
     val in = BatchUnary(VanillaBatch, BatchLeaf(VeloxBatch))
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
     assert(
-      out == ColumnarToRowExec(
-        BatchUnary(VanillaBatch, 
RowToColumnarExec(VeloxColumnarToRowExec(BatchLeaf(VeloxBatch))))))
+      out == ColumnarToRowExec(BatchUnary(VanillaBatch, 
LoadArrowDataExec(BatchLeaf(VeloxBatch)))))
   }
 
   override protected def beforeAll(): Unit = {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala 
b/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
index 5f82a2ee7d..1b9175d6b0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
@@ -39,15 +39,16 @@ trait Backend {
   def onExecutorStart(pc: PluginContext): Unit = {}
   def onExecutorShutdown(): Unit = {}
 
-  /** The columnar-batch type this backend is using. */
-  def batchType: Convention.BatchType
+  /** The columnar-batch type this backend is by default using. */
+  def defaultBatchType: Convention.BatchType
 
   /**
    * Overrides 
[[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is 
using to
    * determine the convention (its row-based processing / columnar-batch 
processing support) of a
-   * plan with a user-defined function that accepts a plan then returns batch 
type it outputs.
+   * plan with a user-defined function that accepts a plan then returns 
convention type it outputs,
+   * and input conventions it requires.
    */
-  def batchTypeFunc(): ConventionFunc.BatchOverride = PartialFunction.empty
+  def convFuncOverride(): ConventionFunc.Override = 
ConventionFunc.Override.Empty
 
   /** Query planner rules. */
   def injectRules(injector: RuleInjector): Unit
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index fcd34cb1f2..d585af7c71 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -61,7 +61,9 @@ object Convention {
     Impl(rowType, batchType)
   }
 
-  sealed trait RowType
+  sealed trait RowType extends TransitionGraph.Vertex with Serializable {
+    Transition.graph.addVertex(this)
+  }
 
   object RowType {
     // None indicates that the plan doesn't support row-based processing.
@@ -69,23 +71,25 @@ object Convention {
     final case object VanillaRow extends RowType
   }
 
-  trait BatchType extends Serializable {
-    final def fromRow(transitionDef: TransitionDef): Unit = {
-      Transition.factory.update().defineFromRowTransition(this, transitionDef)
+  trait BatchType extends TransitionGraph.Vertex with Serializable {
+    Transition.graph.addVertex(this)
+
+    final protected def fromRow(transitionDef: TransitionDef): Unit = {
+      Transition.graph.addEdge(RowType.VanillaRow, this, 
transitionDef.create())
     }
 
-    final def toRow(transitionDef: TransitionDef): Unit = {
-      Transition.factory.update().defineToRowTransition(this, transitionDef)
+    final protected def toRow(transitionDef: TransitionDef): Unit = {
+      Transition.graph.addEdge(this, RowType.VanillaRow, 
transitionDef.create())
     }
 
-    final def fromBatch(from: BatchType, transitionDef: TransitionDef): Unit = 
{
+    final protected def fromBatch(from: BatchType, transitionDef: 
TransitionDef): Unit = {
       assert(from != this)
-      Transition.factory.update().defineBatchTransition(from, this, 
transitionDef)
+      Transition.graph.addEdge(from, this, transitionDef.create())
     }
 
-    final def toBatch(to: BatchType, transitionDef: TransitionDef): Unit = {
+    final protected def toBatch(to: BatchType, transitionDef: TransitionDef): 
Unit = {
       assert(to != this)
-      Transition.factory.update().defineBatchTransition(this, to, 
transitionDef)
+      Transition.graph.addEdge(this, to, transitionDef.create())
     }
   }
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
index 21662f503e..c3feefe943 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
@@ -26,13 +26,21 @@ import 
org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 
 /** ConventionFunc is a utility to derive [[Convention]] or [[ConventionReq]] 
from a query plan. */
-trait ConventionFunc {
+sealed trait ConventionFunc {
   def conventionOf(plan: SparkPlan): Convention
   def conventionReqOf(plan: SparkPlan): ConventionReq
 }
 
 object ConventionFunc {
-  type BatchOverride = PartialFunction[SparkPlan, Convention.BatchType]
+  trait Override {
+    def rowTypeOf: PartialFunction[SparkPlan, Convention.RowType] = 
PartialFunction.empty
+    def batchTypeOf: PartialFunction[SparkPlan, Convention.BatchType] = 
PartialFunction.empty
+    def conventionReqOf: PartialFunction[SparkPlan, ConventionReq] = 
PartialFunction.empty
+  }
+
+  object Override {
+    object Empty extends Override
+  }
 
   // For testing, to make things work without a backend loaded.
   private var ignoreBackend: Boolean = false
@@ -53,17 +61,17 @@ object ConventionFunc {
     new BuiltinFunc(batchOverride)
   }
 
-  private def newOverride(): BatchOverride = {
+  private def newOverride(): Override = {
     synchronized {
       if (ignoreBackend) {
         // For testing
-        return PartialFunction.empty
+        return Override.Empty
       }
     }
-    Backend.get().batchTypeFunc()
+    Backend.get().convFuncOverride()
   }
 
-  private class BuiltinFunc(o: BatchOverride) extends ConventionFunc {
+  private class BuiltinFunc(o: Override) extends ConventionFunc {
     import BuiltinFunc._
     override def conventionOf(plan: SparkPlan): Convention = {
       val conv = conventionOf0(plan)
@@ -86,7 +94,7 @@ object ConventionFunc {
         val batchType = if (a.supportsColumnar) {
           // By default, we execute columnar AQE with backend batch output.
           // See 
org.apache.gluten.extension.columnar.transition.InsertTransitions.apply
-          Backend.get().batchType
+          Backend.get().defaultBatchType
         } else {
           Convention.BatchType.None
         }
@@ -98,6 +106,11 @@ object ConventionFunc {
     }
 
     private def rowTypeOf(plan: SparkPlan): Convention.RowType = {
+      val out = o.rowTypeOf.applyOrElse(plan, rowTypeOf0)
+      out
+    }
+
+    private def rowTypeOf0(plan: SparkPlan): Convention.RowType = {
       val out = plan match {
         case k: Convention.KnownRowType =>
           k.rowType()
@@ -113,25 +126,26 @@ object ConventionFunc {
     }
 
     private def batchTypeOf(plan: SparkPlan): Convention.BatchType = {
-      val out = o.applyOrElse(
-        plan,
-        (p: SparkPlan) =>
-          p match {
-            case k: Convention.KnownBatchType =>
-              k.batchType()
-            case _ if plan.supportsColumnar =>
-              Convention.BatchType.VanillaBatch
-            case _ =>
-              Convention.BatchType.None
-          }
-      )
+      val out = o.batchTypeOf.applyOrElse(plan, batchTypeOf0)
+      out
+    }
+
+    private def batchTypeOf0(plan: SparkPlan): Convention.BatchType = {
+      val out = plan match {
+        case k: Convention.KnownBatchType =>
+          k.batchType()
+        case _ if plan.supportsColumnar =>
+          Convention.BatchType.VanillaBatch
+        case _ =>
+          Convention.BatchType.None
+      }
       assert(out == Convention.BatchType.None || plan.supportsColumnar)
       out
     }
 
     override def conventionReqOf(plan: SparkPlan): ConventionReq = {
-      val out = conventionReqOf0(plan)
-      out
+      val req = o.conventionReqOf.applyOrElse(plan, conventionReqOf0)
+      req
     }
 
     private def conventionReqOf0(plan: SparkPlan): ConventionReq = plan match {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
index cb76ec4de0..ce613bf7db 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
@@ -58,7 +58,7 @@ object ConventionReq {
   val vanillaBatch: ConventionReq =
     Impl(RowType.Any, BatchType.Is(Convention.BatchType.VanillaBatch))
   lazy val backendBatch: ConventionReq =
-    Impl(RowType.Any, BatchType.Is(Backend.get().batchType))
+    Impl(RowType.Any, BatchType.Is(Backend.get().defaultBatchType))
 
   def get(plan: SparkPlan): ConventionReq = 
ConventionFunc.create().conventionReqOf(plan)
   def of(rowType: RowType, batchType: BatchType): ConventionReq = 
Impl(rowType, batchType)
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
new file mode 100644
index 0000000000..2a4e1f4225
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.transition
+
+import scala.collection.mutable
+
+/**
+ * Floyd-Warshall algorithm for finding e.g., cheapest transition between 
query plan nodes.
+ *
+ * https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm
+ */
+trait FloydWarshallGraph[V <: AnyRef, E <: AnyRef] {
+  import FloydWarshallGraph._
+  def hasPath(from: V, to: V): Boolean
+  def pathOf(from: V, to: V): Path[E]
+}
+
+object FloydWarshallGraph {
+  trait Cost {
+    def +(other: Cost): Cost
+  }
+
+  trait CostModel[E <: AnyRef] {
+    def zero(): Cost
+    def costOf(edge: E): Cost
+    def costComparator(): Ordering[Cost]
+  }
+
+  trait Path[E <: AnyRef] {
+    def edges(): Seq[E]
+    def cost(): Cost
+  }
+
+  def builder[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]): Builder[V, 
E] = {
+    Builder.create(costModel)
+  }
+
+  private object Path {
+    def apply[E <: AnyRef](costModel: CostModel[E], edges: Seq[E]): Path[E] = 
Impl(edges)(costModel)
+    private case class Impl[E <: AnyRef](override val edges: 
Seq[E])(costModel: CostModel[E])
+      extends Path[E] {
+      override val cost: Cost = {
+        edges.map(costModel.costOf).reduceOption(_ + 
_).getOrElse(costModel.zero())
+      }
+    }
+  }
+
+  private class Impl[V <: AnyRef, E <: AnyRef](pathTable: Map[V, Map[V, 
Path[E]]])
+    extends FloydWarshallGraph[V, E] {
+    override def hasPath(from: V, to: V): Boolean = {
+      if (!pathTable.contains(from)) {
+        return false
+      }
+      val vec = pathTable(from)
+      if (!vec.contains(to)) {
+        return false
+      }
+      true
+    }
+
+    override def pathOf(from: V, to: V): Path[E] = {
+      assert(hasPath(from, to))
+      val path = pathTable(from)(to)
+      path
+    }
+  }
+
+  trait Builder[V <: AnyRef, E <: AnyRef] {
+    def addVertex(v: V): Builder[V, E]
+    def addEdge(from: V, to: V, edge: E): Builder[V, E]
+    def build(): FloydWarshallGraph[V, E]
+  }
+
+  private object Builder {
+    // Thread safe.
+    private class Impl[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]) 
extends Builder[V, E] {
+      private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] = 
mutable.Map()
+      private var graph: Option[FloydWarshallGraph[V, E]] = None
+
+      override def addVertex(v: V): Builder[V, E] = synchronized {
+        assert(!pathTable.contains(v), s"Vertex $v already exists in graph")
+        pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v, 
Path(costModel, Nil))
+        graph = None
+        this
+      }
+
+      override def addEdge(from: V, to: V, edge: E): Builder[V, E] = 
synchronized {
+        assert(from != to, s"Input vertices $from and $to should be different")
+        assert(pathTable.contains(from), s"Vertex $from not exists in graph")
+        assert(pathTable.contains(to), s"Vertex $to not exists in graph")
+        assert(!hasPath(from, to), s"Path from $from to $to already exists in 
graph")
+        pathTable(from) += to -> Path(costModel, Seq(edge))
+        graph = None
+        this
+      }
+
+      override def build(): FloydWarshallGraph[V, E] = synchronized {
+        if (graph.isEmpty) {
+          graph = Some(compile())
+        }
+        return graph.get
+      }
+
+      private def hasPath(from: V, to: V): Boolean = {
+        if (!pathTable.contains(from)) {
+          return false
+        }
+        val vec = pathTable(from)
+        if (!vec.contains(to)) {
+          return false
+        }
+        true
+      }
+
+      private def compile(): FloydWarshallGraph[V, E] = {
+        val vertices = pathTable.keys
+        for (k <- vertices) {
+          for (i <- vertices) {
+            for (j <- vertices) {
+              if (hasPath(i, k) && hasPath(k, j)) {
+                val pathIk = pathTable(i)(k)
+                val pathKj = pathTable(k)(j)
+                val newPath = Path(costModel, pathIk.edges() ++ pathKj.edges())
+                if (!hasPath(i, j)) {
+                  pathTable(i) += j -> newPath
+                } else {
+                  val path = pathTable(i)(j)
+                  if (costModel.costComparator().compare(newPath.cost(), 
path.cost()) < 0) {
+                    pathTable(i) += j -> newPath
+                  }
+                }
+              }
+            }
+          }
+        }
+        new FloydWarshallGraph.Impl(pathTable.map { case (k, m) => (k, 
m.toMap) }.toMap)
+      }
+    }
+
+    def create[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]): Builder[V, 
E] = {
+      new Impl(costModel)
+    }
+  }
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 3fd2839b5a..87e7204c14 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -18,9 +18,10 @@ package org.apache.gluten.extension.columnar.transition
 
 import org.apache.gluten.exception.GlutenException
 
-import org.apache.spark.sql.execution.SparkPlan
-
-import scala.collection.mutable
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
 
 /**
  * Transition is a simple function to convert a query plan to interested 
[[ConventionReq]].
@@ -29,18 +30,18 @@ import scala.collection.mutable
  * [[org.apache.gluten.extension.columnar.transition.Convention.BatchType]]'s 
definition.
  */
 trait Transition {
+  import Transition._
   final def apply(plan: SparkPlan): SparkPlan = {
     val out = apply0(plan)
-    if (out eq plan) {
-      assert(
-        this == Transition.empty,
-        "TransitionDef.empty / Transition.empty should be used when defining 
an empty transition.")
-    }
     out
   }
 
-  final def isEmpty: Boolean = {
-    this == Transition.empty
+  final lazy val isEmpty: Boolean = {
+    // Tests if a transition is actually no-op.
+    val plan = DummySparkPlan()
+    val out = apply0(plan)
+    val identical = out eq plan
+    identical
   }
 
   protected def apply0(plan: SparkPlan): SparkPlan
@@ -56,7 +57,10 @@ object TransitionDef {
 
 object Transition {
   val empty: Transition = (plan: SparkPlan) => plan
-  val factory: Factory = Factory.newBuiltin()
+  private val abort: Transition = (_: SparkPlan) => throw new 
UnsupportedOperationException("Abort")
+  private[transition] val graph: TransitionGraph.Builder = 
TransitionGraph.builder()
+
+  def factory(): Factory = Factory.newBuiltin(graph.build())
 
   def notFound(plan: SparkPlan): GlutenException = {
     new GlutenException(s"No viable transition found from plan's child to 
itself: $plan")
@@ -66,19 +70,6 @@ object Transition {
     new GlutenException(s"No viable transition to [$required] found for plan: 
$plan")
   }
 
-  private class ChainedTransition(first: Transition, second: Transition) 
extends Transition {
-    override def apply0(plan: SparkPlan): SparkPlan = {
-      second(first(plan))
-    }
-  }
-
-  private def chain(first: Transition, second: Transition): Transition = {
-    if (first.isEmpty && second.isEmpty) {
-      return Transition.empty
-    }
-    new ChainedTransition(first, second)
-  }
-
   trait Factory {
     final def findTransition(
         from: Convention,
@@ -90,63 +81,20 @@ object Transition {
     }
 
     final def satisfies(conv: Convention, req: ConventionReq): Boolean = {
-      val none = new Transition {
-        override protected def apply0(plan: SparkPlan): SparkPlan =
-          throw new UnsupportedOperationException()
-      }
-      val transition = findTransition(conv, req)(none)
+      val transition = findTransition(conv, req)(abort)
       transition.isEmpty
     }
 
     protected def findTransition(from: Convention, to: ConventionReq)(
         orElse: => Transition): Transition
-    private[transition] def update(): MutableFactory
-  }
-
-  trait MutableFactory extends Factory {
-    def defineFromRowTransition(to: Convention.BatchType, transitionDef: 
TransitionDef): Unit
-    def defineToRowTransition(from: Convention.BatchType, transitionDef: 
TransitionDef): Unit
-    def defineBatchTransition(
-        from: Convention.BatchType,
-        to: Convention.BatchType,
-        transitionDef: TransitionDef): Unit
   }
 
   private object Factory {
-    def newBuiltin(): Factory = {
-      new BuiltinFactory
+    def newBuiltin(graph: TransitionGraph): Factory = {
+      new BuiltinFactory(graph)
     }
 
-    private class BuiltinFactory extends MutableFactory {
-      private val fromRowTransitions: mutable.Map[Convention.BatchType, 
TransitionDef] =
-        mutable.Map()
-      private val toRowTransitions: mutable.Map[Convention.BatchType, 
TransitionDef] = mutable.Map()
-      private val batchTransitions
-          : mutable.Map[(Convention.BatchType, Convention.BatchType), 
TransitionDef] =
-        mutable.Map()
-
-      override def defineFromRowTransition(
-          to: Convention.BatchType,
-          transitionDef: TransitionDef): Unit = {
-        assert(!fromRowTransitions.contains(to))
-        fromRowTransitions += to -> transitionDef
-      }
-
-      override def defineToRowTransition(
-          from: Convention.BatchType,
-          transitionDef: TransitionDef): Unit = {
-        assert(!toRowTransitions.contains(from))
-        toRowTransitions += from -> transitionDef
-      }
-
-      override def defineBatchTransition(
-          from: Convention.BatchType,
-          to: Convention.BatchType,
-          transitionDef: TransitionDef): Unit = {
-        assert(!batchTransitions.contains((from, to)))
-        batchTransitions += (from, to) -> transitionDef
-      }
-
+    private class BuiltinFactory(graph: TransitionGraph) extends Factory {
       override def findTransition(from: Convention, to: ConventionReq)(
           orElse: => Transition): Transition = {
         assert(
@@ -165,7 +113,7 @@ object Transition {
           case (ConventionReq.RowType.Is(toRowType), 
ConventionReq.BatchType.Any) =>
             from.rowType match {
               case Convention.RowType.None =>
-                
toRowTransitions.get(from.batchType).map(_.create()).getOrElse(orElse)
+                graph.transitionOfOption(from.batchType, 
toRowType).getOrElse(orElse)
               case fromRowType =>
                 // We have only one single built-in row type.
                 assert(toRowType == fromRowType)
@@ -174,27 +122,12 @@ object Transition {
           case (ConventionReq.RowType.Any, 
ConventionReq.BatchType.Is(toBatchType)) =>
             from.batchType match {
               case Convention.BatchType.None =>
-                
fromRowTransitions.get(toBatchType).map(_.create()).getOrElse(orElse)
+                graph.transitionOfOption(from.rowType, 
toBatchType).getOrElse(orElse)
               case fromBatchType =>
                 if (toBatchType == fromBatchType) {
                   Transition.empty
                 } else {
-                  // Batch type conversion needed.
-                  //
-                  // We first look up for batch-to-batch transition. If found 
one, return that
-                  // transition to caller. Otherwise, look for from/to row 
transitions, then
-                  // return a bridged batch-to-row-to-batch transition.
-                  if (batchTransitions.contains((fromBatchType, toBatchType))) 
{
-                    // 1. Found batch-to-batch transition.
-                    batchTransitions((fromBatchType, toBatchType)).create()
-                  } else {
-                    // 2. Otherwise, build up batch-to-row-to-batch transition.
-                    val batchToRow =
-                      
toRowTransitions.get(fromBatchType).map(_.create()).getOrElse(orElse)
-                    val rowToBatch =
-                      
fromRowTransitions.get(toBatchType).map(_.create()).getOrElse(orElse)
-                    chain(batchToRow, rowToBatch)
-                  }
+                  graph.transitionOfOption(fromBatchType, 
toBatchType).getOrElse(orElse)
                 }
             }
           case (ConventionReq.RowType.Any, ConventionReq.BatchType.Any) =>
@@ -205,8 +138,12 @@ object Transition {
         }
         out
       }
-
-      override private[transition] def update(): MutableFactory = this
     }
   }
+
+  private case class DummySparkPlan() extends LeafExecNode {
+    override def supportsColumnar: Boolean = true // To bypass the assertion 
in ColumnarToRowExec.
+    override protected def doExecute(): RDD[InternalRow] = throw new 
UnsupportedOperationException()
+    override def output: Seq[Attribute] = Nil
+  }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
new file mode 100644
index 0000000000..9cafcae8b5
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.transition
+
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.util.SparkReflectionUtil
+
+object TransitionGraph {
+  trait Vertex {
+    override def toString: String = 
SparkReflectionUtil.getSimpleClassName(this.getClass)
+  }
+
+  type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition]
+
+  def builder(): Builder = {
+    FloydWarshallGraph.builder(TransitionCostModel)
+  }
+
+  implicit class TransitionGraphOps(val graph: TransitionGraph) {
+    import TransitionGraphOps._
+    def hasTransition(from: TransitionGraph.Vertex, to: 
TransitionGraph.Vertex): Boolean = {
+      graph.hasPath(from, to)
+    }
+
+    def transitionOf(from: TransitionGraph.Vertex, to: 
TransitionGraph.Vertex): Transition = {
+      val path = graph.pathOf(from, to)
+      val out = path.edges().reduceOption((l, r) => chain(l, 
r)).getOrElse(Transition.empty)
+      out
+    }
+
+    def transitionOfOption(
+        from: TransitionGraph.Vertex,
+        to: TransitionGraph.Vertex): Option[Transition] = {
+      if (!hasTransition(from, to)) {
+        return None
+      }
+      Some(transitionOf(from, to))
+    }
+  }
+
+  private case class ChainedTransition(first: Transition, second: Transition) 
extends Transition {
+    override def apply0(plan: SparkPlan): SparkPlan = {
+      second(first(plan))
+    }
+  }
+
+  private object TransitionGraphOps {
+    private def chain(first: Transition, second: Transition): Transition = {
+      if (first.isEmpty && second.isEmpty) {
+        return Transition.empty
+      }
+      ChainedTransition(first, second)
+    }
+  }
+
+  private case class TransitionCost(count: Int) extends 
FloydWarshallGraph.Cost {
+    override def +(other: FloydWarshallGraph.Cost): TransitionCost = {
+      other match {
+        case TransitionCost(otherCount) => TransitionCost(count + otherCount)
+      }
+    }
+  }
+
+  private object TransitionCostModel extends 
FloydWarshallGraph.CostModel[Transition] {
+    override def zero(): FloydWarshallGraph.Cost = TransitionCost(0)
+    override def costOf(transition: Transition): FloydWarshallGraph.Cost = 
costOf0(transition)
+    override def costComparator(): Ordering[FloydWarshallGraph.Cost] = 
Ordering.Int.on {
+      case TransitionCost(c) => c
+    }
+    private def costOf0(transition: Transition): TransitionCost = transition 
match {
+      case t if t.isEmpty => TransitionCost(0)
+      case ChainedTransition(f, s) => costOf0(f) + costOf0(s)
+      case _ => TransitionCost(1)
+    }
+  }
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index 1441814519..9987a65b0c 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -89,7 +89,7 @@ object Transitions {
   }
 
   def toBackendBatchPlan(plan: SparkPlan): SparkPlan = {
-    val backendBatchType = Backend.get().batchType
+    val backendBatchType = Backend.get().defaultBatchType
     val out = toBatchPlan(plan, backendBatchType)
     out
   }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
index 2dd8d632e3..0a0deb17de 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
@@ -23,6 +23,7 @@ import 
org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
 import org.apache.spark.sql.execution.debug.DebugExec
 
 package object transition {
+  type TransitionGraph = FloydWarshallGraph[TransitionGraph.Vertex, Transition]
   // These 5 plan operators (as of Spark 3.5) are operators that have the
   // same convention with their children.
   //
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/util/SparkReflectionUtil.scala 
b/gluten-core/src/main/scala/org/apache/spark/util/SparkReflectionUtil.scala
new file mode 100644
index 0000000000..40692346e0
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkReflectionUtil.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+object SparkReflectionUtil {
+  def getSimpleClassName(cls: Class[_]): String = {
+    Utils.getSimpleName(cls)
+  }
+}
diff --git 
a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
 
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
new file mode 100644
index 0000000000..6bc4ab804f
--- /dev/null
+++ 
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.transition
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util.concurrent.atomic.AtomicInteger
+
+class FloydWarshallGraphSuite extends AnyFunSuite {
+  import FloydWarshallGraphSuite._
+  test("Sanity") {
+    val v0 = Vertex()
+    val v1 = Vertex()
+    val v2 = Vertex()
+    val v3 = Vertex()
+    val v4 = Vertex()
+
+    val e01 = Edge(5)
+    val e12 = Edge(6)
+    val e03 = Edge(2)
+    val e34 = Edge(1)
+    val e42 = Edge(3)
+
+    val graph = FloydWarshallGraph
+      .builder(CostModel)
+      .addVertex(v0)
+      .addVertex(v1)
+      .addVertex(v2)
+      .addVertex(v3)
+      .addVertex(v4)
+      .addEdge(v0, v1, e01)
+      .addEdge(v1, v2, e12)
+      .addEdge(v0, v3, e03)
+      .addEdge(v3, v4, e34)
+      .addEdge(v4, v2, e42)
+      .build()
+
+    assert(graph.hasPath(v0, v1))
+    assert(graph.hasPath(v0, v2))
+    assert(!graph.hasPath(v1, v0))
+    assert(!graph.hasPath(v2, v0))
+
+    assert(graph.pathOf(v0, v0).edges() == Nil)
+
+    assert(graph.pathOf(v0, v1).edges() == Seq(e01))
+    assert(graph.pathOf(v1, v2).edges() == Seq(e12))
+    assert(graph.pathOf(v0, v3).edges() == Seq(e03))
+    assert(graph.pathOf(v3, v4).edges() == Seq(e34))
+    assert(graph.pathOf(v4, v2).edges() == Seq(e42))
+
+    assert(graph.pathOf(v0, v2).edges() == Seq(e03, e34, e42))
+  }
+}
+
+private object FloydWarshallGraphSuite {
+  case class Vertex private (id: Int)
+
+  private object Vertex {
+    private val id = new AtomicInteger(0)
+
+    def apply(): Vertex = {
+      Vertex(id.getAndIncrement())
+    }
+  }
+
+  case class Edge private (id: Int, distance: Long)
+
+  private object Edge {
+    private val id = new AtomicInteger(0)
+
+    def apply(distance: Long): Edge = {
+      Edge(id.getAndIncrement(), distance)
+    }
+  }
+
+  private case class LongCost(c: Long) extends FloydWarshallGraph.Cost {
+    override def +(other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = 
other match {
+      case LongCost(o) => LongCost(c + o)
+    }
+  }
+
+  private object CostModel extends FloydWarshallGraph.CostModel[Edge] {
+    override def zero(): FloydWarshallGraph.Cost = LongCost(0)
+    override def costOf(edge: Edge): FloydWarshallGraph.Cost = 
LongCost(edge.distance * 10)
+    override def costComparator(): Ordering[FloydWarshallGraph.Cost] = 
Ordering.Long.on {
+      case LongCost(c) => c
+    }
+  }
+}
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 9b05d567e2..c658a43760 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -109,7 +109,7 @@ trait GlutenPlan extends SparkPlan with 
Convention.KnownBatchType with LogLevelU
   }
 
   protected def batchType0(): Convention.BatchType = {
-    Backend.get().batchType
+    Backend.get().defaultBatchType
   }
 
   protected def doValidateInternal(): ValidationResult = 
ValidationResult.succeeded
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala 
b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
index c38ea8af02..9ebd722781 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
@@ -27,7 +27,7 @@ import scala.annotation.tailrec
 
 object PlanUtil {
   private def isGlutenTableCacheInternal(i: InMemoryTableScanExec): Boolean = {
-    Convention.get(i).batchType == Backend.get().batchType
+    Convention.get(i).batchType == Backend.get().defaultBatchType
   }
 
   @tailrec
@@ -47,6 +47,6 @@ object PlanUtil {
   }
 
   def isGlutenColumnarOp(plan: SparkPlan): Boolean = {
-    Convention.get(plan).batchType == Backend.get().batchType
+    Convention.get(plan).batchType == Backend.get().defaultBatchType
   }
 }
diff --git 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
index c57df192c5..ada7283da8 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
@@ -166,7 +166,7 @@ case class ColumnarInputAdapter(child: SparkPlan)
   override def output: Seq[Attribute] = child.output
   override def supportsColumnar: Boolean = true
   override def batchType(): Convention.BatchType =
-    Backend.get().batchType
+    Backend.get().defaultBatchType
   override protected def doExecute(): RDD[InternalRow] = throw new 
UnsupportedOperationException()
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = 
child.executeColumnar()
   override def outputPartitioning: Partitioning = child.outputPartitioning


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to