This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 152045891c [CORE][VL] Cost model code refactors (#8541)
152045891c is described below

commit 152045891c2baf94781b29a5831ba20bd1e84586
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Jan 16 13:54:00 2025 +0800

    [CORE][VL] Cost model code refactors (#8541)
---
 .../gluten/backendsapi/clickhouse/CHBackend.scala  |  4 +-
 .../gluten/backendsapi/clickhouse/CHRuleApi.scala  |  4 --
 .../gluten/backendsapi/velox/VeloxBackend.scala    |  4 +-
 .../gluten/backendsapi/velox/VeloxRuleApi.scala    |  5 --
 .../gluten/execution/MiscOperatorSuite.scala       |  1 -
 .../enumerated/planner/VeloxRasSuite.scala         | 30 +++++-----
 .../org/apache/gluten/component/Component.scala    |  7 +++
 .../GlutenCost.scala}                              | 15 +----
 .../extension/columnar/cost/GlutenCostModel.scala  | 66 ++++++++++++++++++++++
 .../{enumerated/planner => }/cost/LongCost.scala   |  6 +-
 .../planner => }/cost/LongCostModel.scala          | 15 +++--
 .../{enumerated/planner => }/cost/LongCoster.scala |  2 +-
 .../planner => }/cost/LongCosterChain.scala        |  3 +-
 .../columnar/enumerated/EnumeratedTransform.scala  | 20 ++++++-
 .../enumerated/planner/property/Conv.scala         |  4 +-
 .../extension/columnar/transition/Convention.scala |  8 +--
 .../columnar/transition/FloydWarshallGraph.scala   | 50 +++++++---------
 .../extension/columnar/transition/Transition.scala | 43 ++++++++++----
 .../columnar/transition/TransitionGraph.scala      | 37 ++++++------
 .../columnar/transition/Transitions.scala          |  5 +-
 .../gluten/extension/injector/GlutenInjector.scala | 32 +----------
 .../transition/FloydWarshallGraphSuite.scala       |  4 +-
 .../planner => }/cost/LegacyCoster.scala           |  2 +-
 .../planner => }/cost/RoughCoster.scala            |  2 +-
 .../columnar/transition/TransitionSuite.scala      |  7 +--
 .../org/apache/gluten/config/GlutenConfig.scala    |  7 ++-
 26 files changed, 222 insertions(+), 161 deletions(-)

diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index aa9e3e553c..9626987593 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -24,6 +24,7 @@ import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
 import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster}
 import org.apache.gluten.extension.columnar.transition.{Convention, 
ConventionFunc}
 import org.apache.gluten.substrait.rel.LocalFilesNode
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
@@ -54,7 +55,6 @@ class CHBackend extends SubstraitBackend {
   override def name(): String = CHConf.BACKEND_NAME
   override def buildInfo(): BuildInfo =
     BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
-  override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
   override def iteratorApi(): IteratorApi = new CHIteratorApi
   override def sparkPlanExecApi(): SparkPlanExecApi = new CHSparkPlanExecApi
   override def transformerApi(): TransformerApi = new CHTransformerApi
@@ -63,6 +63,8 @@ class CHBackend extends SubstraitBackend {
   override def listenerApi(): ListenerApi = new CHListenerApi
   override def ruleApi(): RuleApi = new CHRuleApi
   override def settings(): BackendSettingsApi = CHBackendSettings
+  override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
+  override def costers(): Seq[LongCoster] = Seq(LegacyCoster)
 }
 
 object CHBackend {
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 426c88c907..21ae342a22 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -22,7 +22,6 @@ import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.extension._
 import org.apache.gluten.extension.columnar._
 import 
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
 RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
-import 
org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster
 import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, 
HeuristicTransform}
 import org.apache.gluten.extension.columnar.offload.{OffloadExchange, 
OffloadJoin, OffloadOthers}
 import org.apache.gluten.extension.columnar.rewrite._
@@ -143,9 +142,6 @@ object CHRuleApi {
   }
 
   private def injectRas(injector: RasInjector): Unit = {
-    // Register legacy coster for transition planner.
-    injector.injectCoster(_ => LegacyCoster)
-
     // CH backend doesn't work with RAS at the moment. Inject a rule that 
aborts any
     // execution calls.
     injector.injectPreTransform(
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 519e98c5d4..677d8792c7 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -25,6 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
 import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster, 
RoughCoster}
 import org.apache.gluten.extension.columnar.transition.{Convention, 
ConventionFunc}
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.rel.LocalFilesNode
@@ -61,7 +62,6 @@ class VeloxBackend extends SubstraitBackend {
   override def name(): String = VeloxBackend.BACKEND_NAME
   override def buildInfo(): BuildInfo =
     BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, VELOX_REVISION_TIME)
-  override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
   override def iteratorApi(): IteratorApi = new VeloxIteratorApi
   override def sparkPlanExecApi(): SparkPlanExecApi = new VeloxSparkPlanExecApi
   override def transformerApi(): TransformerApi = new VeloxTransformerApi
@@ -70,6 +70,8 @@ class VeloxBackend extends SubstraitBackend {
   override def listenerApi(): ListenerApi = new VeloxListenerApi
   override def ruleApi(): RuleApi = new VeloxRuleApi
   override def settings(): BackendSettingsApi = VeloxBackendSettings
+  override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
+  override def costers(): Seq[LongCoster] = Seq(LegacyCoster, RoughCoster)
 }
 
 object VeloxBackend {
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 6c60ab7d53..0cf6ac6713 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -23,7 +23,6 @@ import org.apache.gluten.extension._
 import org.apache.gluten.extension.columnar._
 import 
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
 RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
 import org.apache.gluten.extension.columnar.enumerated.{RasOffload, RemoveSort}
-import 
org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster, 
RoughCoster}
 import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy, 
HeuristicTransform}
 import org.apache.gluten.extension.columnar.offload.{OffloadExchange, 
OffloadJoin, OffloadOthers}
 import org.apache.gluten.extension.columnar.rewrite._
@@ -120,10 +119,6 @@ object VeloxRuleApi {
   }
 
   private def injectRas(injector: RasInjector): Unit = {
-    // Gluten RAS: Costers.
-    injector.injectCoster(_ => LegacyCoster)
-    injector.injectCoster(_ => RoughCoster)
-
     // Gluten RAS: Pre rules.
     injector.injectPreTransform(_ => RemoveTransitions)
     injector.injectPreTransform(_ => PushDownInputFileExpression.PreOffload)
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 5b6b8a30ed..f0be15f07d 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -34,7 +34,6 @@ import java.util.concurrent.TimeUnit
 import scala.collection.JavaConverters
 
 class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with 
AdaptiveSparkPlanHelper {
-
   protected val rootPath: String = getClass.getResource("/").getPath
   override protected val resourcePath: String = "/tpch-data-parquet"
   override protected val fileFormat: String = "parquet"
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
index e7de629b39..050a881394 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
@@ -17,11 +17,11 @@
 package org.apache.gluten.extension.columnar.enumerated.planner
 
 import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel, 
LegacyCoster, LongCostModel}
 import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
-import 
org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, 
LegacyCoster, LongCostModel}
 import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
 import org.apache.gluten.extension.columnar.transition.{Convention, 
ConventionReq}
-import org.apache.gluten.ras.{Cost, Ras}
+import org.apache.gluten.ras.Ras
 import org.apache.gluten.ras.RasSuiteBase._
 import org.apache.gluten.ras.path.RasPath
 import org.apache.gluten.ras.property.PropertySet
@@ -152,7 +152,7 @@ object VeloxRasSuite {
   def newRas(rasRules: Seq[RasRule[SparkPlan]]): Ras[SparkPlan] = {
     GlutenOptimization
       .builder()
-      .costModel(sessionCostModel())
+      .costModel(EnumeratedTransform.asRasCostModel(sessionCostModel()))
       .addRules(rasRules)
       .create()
       .asInstanceOf[Ras[SparkPlan]]
@@ -205,27 +205,27 @@ object VeloxRasSuite {
 
   class UserCostModel1 extends GlutenCostModel {
     private val base = legacyCostModel()
-    override def costOf(node: SparkPlan): Cost = node match {
+    override def costOf(node: SparkPlan): GlutenCost = node match {
       case _: RowUnary => base.makeInfCost()
       case other => base.costOf(other)
     }
-    override def costComparator(): Ordering[Cost] = base.costComparator()
-    override def makeInfCost(): Cost = base.makeInfCost()
-    override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
-    override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
-    override def makeZeroCost(): Cost = base.makeZeroCost()
+    override def costComparator(): Ordering[GlutenCost] = base.costComparator()
+    override def makeInfCost(): GlutenCost = base.makeInfCost()
+    override def sum(one: GlutenCost, other: GlutenCost): GlutenCost = 
base.sum(one, other)
+    override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = 
base.diff(one, other)
+    override def makeZeroCost(): GlutenCost = base.makeZeroCost()
   }
 
   class UserCostModel2 extends GlutenCostModel {
     private val base = legacyCostModel()
-    override def costOf(node: SparkPlan): Cost = node match {
+    override def costOf(node: SparkPlan): GlutenCost = node match {
       case _: ColumnarUnary => base.makeInfCost()
       case other => base.costOf(other)
     }
-    override def costComparator(): Ordering[Cost] = base.costComparator()
-    override def makeInfCost(): Cost = base.makeInfCost()
-    override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
-    override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
-    override def makeZeroCost(): Cost = base.makeZeroCost()
+    override def costComparator(): Ordering[GlutenCost] = base.costComparator()
+    override def makeInfCost(): GlutenCost = base.makeInfCost()
+    override def sum(one: GlutenCost, other: GlutenCost): GlutenCost = 
base.sum(one, other)
+    override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = 
base.diff(one, other)
+    override def makeZeroCost(): GlutenCost = base.makeZeroCost()
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala 
b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
index 4a066e1484..bc8640a6bc 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.gluten.component
 
+import org.apache.gluten.extension.columnar.cost.LongCoster
 import org.apache.gluten.extension.columnar.transition.ConventionFunc
 import org.apache.gluten.extension.injector.Injector
 
@@ -69,6 +70,12 @@ trait Component {
    */
   def convFuncOverride(): ConventionFunc.Override = 
ConventionFunc.Override.Empty
 
+  /**
+   * A sequence of [[org.apache.gluten.extension.columnar.cost.LongCoster]] 
Gluten is using for cost
+   * evaluation.
+   */
+  def costers(): Seq[LongCoster] = Nil
+
   /** Query planner rules. */
   def injectRules(injector: Injector): Unit
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala
similarity index 66%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala
index 41e5529d2e..08a21549a0 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala
@@ -14,17 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
 
-import org.apache.gluten.ras.{Cost, CostModel}
-
-import org.apache.spark.sql.execution.SparkPlan
-
-trait GlutenCostModel extends CostModel[SparkPlan] {
-  // Returns cost value of one + other.
-  def sum(one: Cost, other: Cost): Cost
-  // Returns cost value of one - other.
-  def diff(one: Cost, other: Cost): Cost
-
-  def makeZeroCost(): Cost
-}
+trait GlutenCost
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala
new file mode 100644
index 0000000000..80edf8919f
--- /dev/null
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.cost
+
+import org.apache.gluten.component.Component
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.util.SparkReflectionUtil
+
+/**
+ * The cost model API of Gluten. Used by:
+ *   1. RAS planner for cost-based optimization; 2. Transition graph for 
choosing transition paths.
+ */
+trait GlutenCostModel {
+  def costOf(node: SparkPlan): GlutenCost
+  def costComparator(): Ordering[GlutenCost]
+  def makeZeroCost(): GlutenCost
+  def makeInfCost(): GlutenCost
+  // Returns cost value of one + other.
+  def sum(one: GlutenCost, other: GlutenCost): GlutenCost
+  // Returns cost value of one - other.
+  def diff(one: GlutenCost, other: GlutenCost): GlutenCost
+}
+
+object GlutenCostModel extends Logging {
+  def find(aliasOrClass: String): GlutenCostModel = {
+    val costModelRegistry = LongCostModel.registry()
+    // Components should override Backend's costers. Hence, reversed 
registration order is applied.
+    Component
+      .sorted()
+      .reverse
+      .flatMap(_.costers())
+      .foreach(coster => costModelRegistry.register(coster))
+    val costModel = find(costModelRegistry, aliasOrClass)
+    costModel
+  }
+
+  private def find(registry: LongCostModel.Registry, aliasOrClass: String): 
GlutenCostModel = {
+    if (LongCostModel.Kind.values().contains(aliasOrClass)) {
+      val kind = LongCostModel.Kind.values()(aliasOrClass)
+      val model = registry.get(kind)
+      return model
+    }
+    val clazz = SparkReflectionUtil.classForName(aliasOrClass)
+    logInfo(s"Using user cost model: $aliasOrClass")
+    val ctor = clazz.getDeclaredConstructor()
+    ctor.setAccessible(true)
+    val model: GlutenCostModel = 
ctor.newInstance().asInstanceOf[GlutenCostModel]
+    model
+  }
+}
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala
similarity index 84%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala
index aa74f7736f..7de8407ffe 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala
@@ -14,8 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
 
-import org.apache.gluten.ras.Cost
-
-case class LongCost(value: Long) extends Cost
+case class LongCost(value: Long) extends GlutenCost
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
similarity index 88%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
index 0d11541b73..2cdf86e6af 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
@@ -14,11 +14,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
 
 import org.apache.gluten.exception.GlutenException
 import 
org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec
-import org.apache.gluten.ras.Cost
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlan
@@ -39,15 +38,15 @@ abstract class LongCostModel extends GlutenCostModel {
     assert(a >= 0)
     assert(b >= 0)
     val sum = a + b
-    if (sum < a || sum < b) Long.MaxValue else sum
+    if (sum < a || sum < b) infLongCost else sum
   }
 
-  override def sum(one: Cost, other: Cost): LongCost = (one, other) match {
+  override def sum(one: GlutenCost, other: GlutenCost): LongCost = (one, 
other) match {
     case (LongCost(value), LongCost(otherValue)) => LongCost(safeSum(value, 
otherValue))
   }
 
   // Returns cost value of one - other.
-  override def diff(one: Cost, other: Cost): Cost = (one, other) match {
+  override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = (one, 
other) match {
     case (LongCost(value), LongCost(otherValue)) =>
       val d = Math.subtractExact(value, otherValue)
       require(d >= zeroLongCost, s"Difference between cost $one and $other 
should not be negative")
@@ -62,13 +61,13 @@ abstract class LongCostModel extends GlutenCostModel {
 
   def selfLongCostOf(node: SparkPlan): Long
 
-  override def costComparator(): Ordering[Cost] = Ordering.Long.on {
+  override def costComparator(): Ordering[GlutenCost] = Ordering.Long.on {
     case LongCost(value) => value
     case _ => throw new IllegalStateException("Unexpected cost type")
   }
 
-  override def makeInfCost(): Cost = LongCost(infLongCost)
-  override def makeZeroCost(): Cost = LongCost(zeroLongCost)
+  override def makeInfCost(): GlutenCost = LongCost(infLongCost)
+  override def makeZeroCost(): GlutenCost = LongCost(zeroLongCost)
 }
 
 object LongCostModel extends Logging {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala
similarity index 95%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala
index f06d1a4db8..8346f8987c 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
 
 import org.apache.spark.sql.execution.SparkPlan
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala
similarity index 96%
rename from 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala
rename to 
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala
index 00980e7712..c7fe616a4d 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala
@@ -14,7 +14,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
+
 import org.apache.gluten.exception.GlutenException
 
 import org.apache.spark.sql.execution.SparkPlan
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 59e829e179..72926407c5 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -19,12 +19,13 @@ package org.apache.gluten.extension.columnar.enumerated
 import org.apache.gluten.component.Component
 import org.apache.gluten.exception.GlutenException
 import 
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
+import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel}
 import 
org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
-import 
org.apache.gluten.extension.columnar.enumerated.planner.cost.GlutenCostModel
 import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
 import org.apache.gluten.extension.injector.Injector
 import org.apache.gluten.extension.util.AdaptiveContext
 import org.apache.gluten.logging.LogLevelUtil
+import org.apache.gluten.ras.{Cost, CostModel}
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.ras.rule.RasRule
 
@@ -47,11 +48,12 @@ import org.apache.spark.sql.execution._
 case class EnumeratedTransform(costModel: GlutenCostModel, rules: 
Seq[RasRule[SparkPlan]])
   extends Rule[SparkPlan]
   with LogLevelUtil {
+  import EnumeratedTransform._
 
   private val optimization = {
     GlutenOptimization
       .builder()
-      .costModel(costModel)
+      .costModel(asRasCostModel(costModel))
       .addRules(rules)
       .create()
   }
@@ -82,4 +84,18 @@ object EnumeratedTransform {
     val call = new ColumnarRuleCall(session, AdaptiveContext(session), false)
     dummyInjector.gluten.ras.createEnumeratedTransform(call)
   }
+
+  def asRasCostModel(gcm: GlutenCostModel): CostModel[SparkPlan] = {
+    new CostModelAdapter(gcm)
+  }
+
+  /** The adapter to make GlutenCostModel comply with RAS cost model. */
+  private class CostModelAdapter(gcm: GlutenCostModel) extends 
CostModel[SparkPlan] {
+    override def costOf(node: SparkPlan): Cost = CostAdapter(gcm.costOf(node))
+    override def costComparator(): Ordering[Cost] =
+      gcm.costComparator().on[Cost] { case CostAdapter(gc) => gc }
+    override def makeInfCost(): Cost = CostAdapter(gcm.makeInfCost())
+  }
+
+  private case class CostAdapter(gc: GlutenCost) extends Cost
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
index 9fa0a839a4..ff530d49bc 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
@@ -46,7 +46,7 @@ sealed trait Conv extends Property[SparkPlan] {
       return true
     }
     val prop = this.asInstanceOf[Prop]
-    val out = Transition.factory().satisfies(prop.prop, req.req)
+    val out = Transition.factory.satisfies(prop.prop, req.req)
     out
   }
 }
@@ -64,7 +64,7 @@ object Conv {
   def findTransition(from: Conv, to: Conv): Transition = {
     val prop = from.asInstanceOf[Prop]
     val req = to.asInstanceOf[Req]
-    val out = Transition.factory().findTransition(prop.prop, req.req, new 
IllegalStateException())
+    val out = Transition.factory.findTransition(prop.prop, req.req, new 
IllegalStateException())
     out
   }
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index ff0f295852..cd341410cf 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -116,21 +116,21 @@ object Convention {
     protected[this] def registerTransitions(): Unit
 
     final protected[this] def fromRow(transition: Transition): Unit = {
-      Transition.graph.addEdge(RowType.VanillaRow, this, transition)
+      Transition.factory.update(graph => graph.addEdge(RowType.VanillaRow, 
this, transition))
     }
 
     final protected[this] def toRow(transition: Transition): Unit = {
-      Transition.graph.addEdge(this, RowType.VanillaRow, transition)
+      Transition.factory.update(graph => graph.addEdge(this, 
RowType.VanillaRow, transition))
     }
 
     final protected[this] def fromBatch(from: BatchType, transition: 
Transition): Unit = {
       assert(from != this)
-      Transition.graph.addEdge(from, this, transition)
+      Transition.factory.update(graph => graph.addEdge(from, this, transition))
     }
 
     final protected[this] def toBatch(to: BatchType, transition: Transition): 
Unit = {
       assert(to != this)
-      Transition.graph.addEdge(this, to, transition)
+      Transition.factory.update(graph => graph.addEdge(this, to, transition))
     }
   }
 
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
index b05e939687..00c687d3b6 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
@@ -44,8 +44,8 @@ object FloydWarshallGraph {
     def cost(costModel: CostModel[E]): Cost
   }
 
-  def builder[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]): 
Builder[V, E] = {
-    Builder.create(costModelFactory)
+  def builder[V <: AnyRef, E <: AnyRef](): Builder[V, E] = {
+    Builder.create()
   }
 
   private object Path {
@@ -83,24 +83,22 @@ object FloydWarshallGraph {
   trait Builder[V <: AnyRef, E <: AnyRef] {
     def addVertex(v: V): Builder[V, E]
     def addEdge(from: V, to: V, edge: E): Builder[V, E]
-    def build(): FloydWarshallGraph[V, E]
+    def build(costModel: CostModel[E]): FloydWarshallGraph[V, E]
   }
 
   private object Builder {
-    // Thread safe.
-    private class Impl[V <: AnyRef, E <: AnyRef](costModelFactory: () => 
CostModel[E])
-      extends Builder[V, E] {
+    private class Impl[V <: AnyRef, E <: AnyRef]() extends Builder[V, E] {
       private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] = 
mutable.Map()
       private var graph: Option[FloydWarshallGraph[V, E]] = None
 
-      override def addVertex(v: V): Builder[V, E] = synchronized {
+      override def addVertex(v: V): Builder[V, E] = {
         assert(!pathTable.contains(v), s"Vertex $v already exists in graph")
         pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v, 
Path(Nil))
         graph = None
         this
       }
 
-      override def addEdge(from: V, to: V, edge: E): Builder[V, E] = 
synchronized {
+      override def addEdge(from: V, to: V, edge: E): Builder[V, E] = {
         assert(from != to, s"Input vertices $from and $to should be different")
         assert(pathTable.contains(from), s"Vertex $from not exists in graph")
         assert(pathTable.contains(to), s"Vertex $to not exists in graph")
@@ -110,26 +108,7 @@ object FloydWarshallGraph {
         this
       }
 
-      override def build(): FloydWarshallGraph[V, E] = synchronized {
-        if (graph.isEmpty) {
-          graph = Some(compile())
-        }
-        return graph.get
-      }
-
-      private def hasPath(from: V, to: V): Boolean = {
-        if (!pathTable.contains(from)) {
-          return false
-        }
-        val vec = pathTable(from)
-        if (!vec.contains(to)) {
-          return false
-        }
-        true
-      }
-
-      private def compile(): FloydWarshallGraph[V, E] = {
-        val costModel = costModelFactory()
+      override def build(costModel: CostModel[E]): FloydWarshallGraph[V, E] = {
         val vertices = pathTable.keys
         for (k <- vertices) {
           for (i <- vertices) {
@@ -156,10 +135,21 @@ object FloydWarshallGraph {
         }
         new FloydWarshallGraph.Impl(pathTable.map { case (k, m) => (k, 
m.toMap) }.toMap)
       }
+
+      private def hasPath(from: V, to: V): Boolean = {
+        if (!pathTable.contains(from)) {
+          return false
+        }
+        val vec = pathTable(from)
+        if (!vec.contains(to)) {
+          return false
+        }
+        true
+      }
     }
 
-    def create[V <: AnyRef, E <: AnyRef](costModelFactory: () => 
CostModel[E]): Builder[V, E] = {
-      new Impl(costModelFactory)
+    def create[V <: AnyRef, E <: AnyRef](): Builder[V, E] = {
+      new Impl()
     }
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index e7a073d9ad..41951d6da9 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -16,10 +16,14 @@
  */
 package org.apache.gluten.extension.columnar.transition
 
+import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.extension.columnar.cost.GlutenCostModel
 
 import org.apache.spark.sql.execution.SparkPlan
 
+import scala.collection.mutable
+
 /**
  * Transition is a simple function to convert a query plan to interested 
[[ConventionReq]].
  *
@@ -47,9 +51,7 @@ trait Transition {
 object Transition {
   val empty: Transition = (plan: SparkPlan) => plan
   private val abort: Transition = (_: SparkPlan) => throw new 
UnsupportedOperationException("Abort")
-  private[transition] val graph: TransitionGraph.Builder = 
TransitionGraph.builder()
-
-  def factory(): Factory = Factory.newBuiltin(graph.build())
+  val factory = Factory.newBuiltin()
 
   def notFound(plan: SparkPlan): GlutenException = {
     new GlutenException(s"No viable transition found from plan's child to 
itself: $plan")
@@ -74,16 +76,32 @@ object Transition {
       transition.isEmpty
     }
 
-    protected def findTransition(from: Convention, to: ConventionReq)(
+    def update(body: TransitionGraph.Builder => Unit): Unit
+
+    protected[Factory] def findTransition(from: Convention, to: ConventionReq)(
         orElse: => Transition): Transition
   }
 
   private object Factory {
-    def newBuiltin(graph: TransitionGraph): Factory = {
-      new BuiltinFactory(graph)
+    def newBuiltin(): Factory = {
+      new BuiltinFactory()
     }
 
-    private class BuiltinFactory(graph: TransitionGraph) extends Factory {
+    private class BuiltinFactory() extends Factory {
+      private val graphBuilder: TransitionGraph.Builder = 
TransitionGraph.builder()
+      // Use of this cache allows users to set a new cost model in the same Spark session;
+      // the new cost model then takes effect for subsequent transition-finding requests.
+      private val graphCache = mutable.Map[String, TransitionGraph]()
+
+      private def graph(): TransitionGraph = synchronized {
+        val aliasOrClass = GlutenConfig.get.rasCostModel
+        graphCache.getOrElseUpdate(
+          aliasOrClass, {
+            val base = GlutenCostModel.find(aliasOrClass)
+            graphBuilder.build(TransitionGraph.asTransitionCostModel(base))
+          })
+      }
+
       override def findTransition(from: Convention, to: ConventionReq)(
           orElse: => Transition): Transition = {
         assert(
@@ -104,7 +122,7 @@ object Transition {
               case Convention.RowType.None =>
                 // Input query plan doesn't have recognizable row-based output,
                 // find columnar-to-row transition.
-                graph.transitionOfOption(from.batchType, 
toRowType).getOrElse(orElse)
+                graph().transitionOfOption(from.batchType, 
toRowType).getOrElse(orElse)
               case fromRowType if toRowType == fromRowType =>
                 // We have only one single built-in row type.
                 Transition.empty
@@ -117,12 +135,12 @@ object Transition {
               case Convention.BatchType.None =>
                 // Input query plan doesn't have recognizable columnar output,
                 // find row-to-columnar transition.
-                graph.transitionOfOption(from.rowType, 
toBatchType).getOrElse(orElse)
+                graph().transitionOfOption(from.rowType, 
toBatchType).getOrElse(orElse)
               case fromBatchType if toBatchType == fromBatchType =>
                 Transition.empty
               case fromBatchType =>
                 // Find columnar-to-columnar transition.
-                graph.transitionOfOption(fromBatchType, 
toBatchType).getOrElse(orElse)
+                graph().transitionOfOption(fromBatchType, 
toBatchType).getOrElse(orElse)
             }
           case (ConventionReq.RowType.Any, ConventionReq.BatchType.Any) =>
             Transition.empty
@@ -132,6 +150,11 @@ object Transition {
         }
         out
       }
+
+      override def update(func: TransitionGraph.Builder => Unit): Unit = 
synchronized {
+        func(graphBuilder)
+        graphCache.clear()
+      }
     }
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
index 8e97443831..7dece0b3f5 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
@@ -16,9 +16,8 @@
  */
 package org.apache.gluten.extension.columnar.transition
 
-import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
+import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel}
 import org.apache.gluten.extension.columnar.transition.Convention.BatchType
-import org.apache.gluten.ras.Cost
 
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.SparkReflectionUtil
@@ -40,7 +39,7 @@ object TransitionGraph {
     }
 
     final private def register(): Unit = BatchType.synchronized {
-      Transition.graph.addVertex(this)
+      Transition.factory.update(graph => graph.addVertex(this))
       register0()
     }
 
@@ -51,8 +50,13 @@ object TransitionGraph {
 
   type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition]
 
-  def builder(): Builder = {
-    FloydWarshallGraph.builder(() => new TransitionCostModel())
+  private[transition] def builder(): Builder = {
+    FloydWarshallGraph.builder()
+  }
+
+  private[transition] def asTransitionCostModel(
+      base: GlutenCostModel): FloydWarshallGraph.CostModel[Transition] = {
+    new TransitionCostModel(base)
   }
 
   implicit class TransitionGraphOps(val graph: TransitionGraph) {
@@ -93,21 +97,22 @@ object TransitionGraph {
   }
 
   /** Reuse RAS cost to represent transition cost. */
-  private case class TransitionCost(value: Cost, nodeNames: Seq[String])
+  private case class TransitionCost(value: GlutenCost, nodeNames: Seq[String])
     extends FloydWarshallGraph.Cost
 
   /**
-   * The cost model reuses RAS's cost model to evaluate cost of transitions.
+   * The transition cost model relies on the registered Gluten cost model 
internally to evaluate
+   * cost of transitions.
    *
    * Note the transition graph is built once for all subsequent Spark sessions 
created on the same
-   * driver, so any access to Spark dynamic SQL config in RAS cost model will 
not take effect for
+   * driver, so any access to Spark dynamic SQL config in Gluten cost model 
will not take effect for
    * the transition cost evaluation. Hence, it's not recommended to access 
Spark dynamic
-   * configurations in RAS cost model as well.
+   * configurations in Gluten cost model as well.
    */
-  private class TransitionCostModel() extends 
FloydWarshallGraph.CostModel[Transition] {
-    private val rasCostModel = EnumeratedTransform.static().costModel
+  private class TransitionCostModel(base: GlutenCostModel)
+    extends FloydWarshallGraph.CostModel[Transition] {
 
-    override def zero(): TransitionCost = 
TransitionCost(rasCostModel.makeZeroCost(), Nil)
+    override def zero(): TransitionCost = TransitionCost(base.makeZeroCost(), 
Nil)
     override def costOf(transition: Transition): TransitionCost = {
       costOf0(transition)
     }
@@ -115,13 +120,13 @@ object TransitionGraph {
         one: FloydWarshallGraph.Cost,
         other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = (one, 
other) match {
       case (TransitionCost(c1, p1), TransitionCost(c2, p2)) =>
-        TransitionCost(rasCostModel.sum(c1, c2), p1 ++ p2)
+        TransitionCost(base.sum(c1, c2), p1 ++ p2)
     }
     override def costComparator(): Ordering[FloydWarshallGraph.Cost] = {
       (x: FloydWarshallGraph.Cost, y: FloydWarshallGraph.Cost) =>
         (x, y) match {
           case (TransitionCost(v1, nodeNames1), TransitionCost(v2, 
nodeNames2)) =>
-            val diff = rasCostModel.costComparator().compare(v1, v2)
+            val diff = base.costComparator().compare(v1, v2)
             if (diff != 0) {
               diff
             } else {
@@ -139,14 +144,14 @@ object TransitionGraph {
        * The calculation considers C2C's cost as half of C2R / R2C's cost. So 
query planner prefers
        * C2C than C2R / R2C.
        */
-      def rasCostOfPlan(plan: SparkPlan): Cost = rasCostModel.costOf(plan)
+      def rasCostOfPlan(plan: SparkPlan): GlutenCost = base.costOf(plan)
       def nodeNamesOfPlan(plan: SparkPlan): Seq[String] = {
         plan.map(_.nodeName).reverse
       }
 
       val leafCost = rasCostOfPlan(leaf)
       val accumulatedCost = rasCostOfPlan(transited)
-      val costDiff = rasCostModel.diff(accumulatedCost, leafCost)
+      val costDiff = base.diff(accumulatedCost, leafCost)
 
       val leafNodeNames = nodeNamesOfPlan(leaf)
       val accumulatedNodeNames = nodeNamesOfPlan(transited)
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index 297485d844..6ac847d19e 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -52,7 +52,7 @@ case class InsertTransitions(convReq: ConventionReq) extends 
Rule[SparkPlan] {
           child
         } else {
           val transition =
-            Transition.factory().findTransition(from, convReq, 
Transition.notFound(node))
+            Transition.factory.findTransition(from, convReq, 
Transition.notFound(node))
           val newChild = transition.apply(child)
           newChild
         }
@@ -100,8 +100,7 @@ object Transitions {
   def enforceReq(plan: SparkPlan, req: ConventionReq): SparkPlan = {
     val convFunc = ConventionFunc.create()
     val removed = RemoveTransitions.removeForNode(plan)
-    val transition = Transition
-      .factory()
+    val transition = Transition.factory
       .findTransition(convFunc.conventionOf(removed), req, 
Transition.notFound(removed, req))
     val out = transition.apply(removed)
     out
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
index 23db1c436d..a208db2c96 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
@@ -20,8 +20,8 @@ import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.extension.GlutenColumnarRule
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier
 import 
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
+import org.apache.gluten.extension.columnar.cost.GlutenCostModel
 import org.apache.gluten.extension.columnar.enumerated.{EnumeratedApplier, 
EnumeratedTransform}
-import 
org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel, 
LongCoster, LongCostModel}
 import org.apache.gluten.extension.columnar.heuristic.{HeuristicApplier, 
HeuristicTransform}
 import org.apache.gluten.ras.rule.RasRule
 
@@ -29,7 +29,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.util.SparkReflectionUtil
 
 import scala.collection.mutable
 
@@ -106,7 +105,6 @@ object GlutenInjector {
   class RasInjector extends Logging {
     private val preTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall 
=> Rule[SparkPlan]]
     private val rasRuleBuilders = mutable.Buffer.empty[ColumnarRuleCall => 
RasRule[SparkPlan]]
-    private val costerBuilders = mutable.Buffer.empty[ColumnarRuleCall => 
LongCoster]
     private val postTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall 
=> Rule[SparkPlan]]
 
     def injectPreTransform(builder: ColumnarRuleCall => Rule[SparkPlan]): Unit 
= {
@@ -117,10 +115,6 @@ object GlutenInjector {
       rasRuleBuilders += builder
     }
 
-    def injectCoster(builder: ColumnarRuleCall => LongCoster): Unit = {
-      costerBuilders += builder
-    }
-
     def injectPostTransform(builder: ColumnarRuleCall => Rule[SparkPlan]): 
Unit = {
       postTransformBuilders += builder
     }
@@ -135,31 +129,9 @@ object GlutenInjector {
     def createEnumeratedTransform(call: ColumnarRuleCall): EnumeratedTransform 
= {
       // Build RAS rules.
       val rules = rasRuleBuilders.map(_(call))
-
-      // Build the cost model.
-      val costModelRegistry = LongCostModel.registry()
-      costerBuilders.foreach(cb => costModelRegistry.register(cb(call)))
-      val aliasOrClass = call.glutenConf.rasCostModel
-      val costModel = findCostModel(costModelRegistry, aliasOrClass)
-
+      val costModel = GlutenCostModel.find(call.glutenConf.rasCostModel)
       // Create transform.
       EnumeratedTransform(costModel, rules.toSeq)
     }
-
-    private def findCostModel(
-        registry: LongCostModel.Registry,
-        aliasOrClass: String): GlutenCostModel = {
-      if (LongCostModel.Kind.values().contains(aliasOrClass)) {
-        val kind = LongCostModel.Kind.values()(aliasOrClass)
-        val model = registry.get(kind)
-        return model
-      }
-      val clazz = SparkReflectionUtil.classForName(aliasOrClass)
-      logInfo(s"Using user cost model: $aliasOrClass")
-      val ctor = clazz.getDeclaredConstructor()
-      ctor.setAccessible(true)
-      val model: GlutenCostModel = 
ctor.newInstance().asInstanceOf[GlutenCostModel]
-      model
-    }
   }
 }
diff --git 
a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
 
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
index 7b60940a1a..7d78df45c1 100644
--- 
a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
+++ 
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
@@ -36,7 +36,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite {
     val e42 = Edge(3)
 
     val graph = FloydWarshallGraph
-      .builder(() => CostModel)
+      .builder()
       .addVertex(v0)
       .addVertex(v1)
       .addVertex(v2)
@@ -47,7 +47,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite {
       .addEdge(v0, v3, e03)
       .addEdge(v3, v4, e34)
       .addEdge(v4, v2, e42)
-      .build()
+      .build(CostModel)
 
     assert(graph.hasPath(v0, v1))
     assert(graph.hasPath(v0, v2))
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala
similarity index 96%
rename from 
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala
rename to 
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala
index bb89d0035b..a8e1524fc9 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
 
 import 
org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike, 
ColumnarToRowLike, RowToColumnarLike}
 import org.apache.gluten.utils.PlanUtil
diff --git 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala
 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala
similarity index 97%
rename from 
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala
rename to 
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala
index ab893265ec..caee696df6 100644
--- 
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala
+++ 
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
 
 import org.apache.gluten.execution.RowToColumnarExecBase
 import 
org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike, 
ColumnarToRowLike, RowToColumnarLike}
diff --git 
a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
 
b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
index 2c423783fd..03c9c0f559 100644
--- 
a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
+++ 
b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.backend.Backend
 import org.apache.gluten.component.Component
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.execution.{ColumnarToColumnarExec, GlutenPlan}
-import 
org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster
+import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster}
 import org.apache.gluten.extension.injector.Injector
 
 import org.apache.spark.rdd.RDD
@@ -152,8 +152,7 @@ object TransitionSuite extends TransitionSuiteBase {
     override def name(): String = "dummy-backend"
     override def buildInfo(): Component.BuildInfo =
       Component.BuildInfo("DUMMY_BACKEND", "N/A", "N/A", "N/A")
-    override def injectRules(injector: Injector): Unit = {
-      injector.gluten.ras.injectCoster(_ => LegacyCoster)
-    }
+    override def injectRules(injector: Injector): Unit = {}
+    override def costers(): Seq[LongCoster] = Seq(LegacyCoster)
   }
 }
diff --git 
a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala 
b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index 3f30d00430..e4d3a76326 100644
--- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -1431,12 +1431,15 @@ object GlutenConfig {
       .booleanConf
       .createWithDefault(false)
 
+  // FIXME: This option is no longer only used by RAS. Should change key to
+  //  `spark.gluten.costModel` or something similar.
   val RAS_COST_MODEL =
     buildConf("spark.gluten.ras.costModel")
       .doc(
         "The class name of user-defined cost model that will be used by 
Gluten's transition " +
-          "planner as well as by RAS. If not specified, a legacy built-in cost 
model that " +
-          "exhaustively offloads computations will be used.")
+          "planner as well as by RAS. If not specified, a legacy built-in cost 
model will be " +
+          "used. The legacy cost model helps RAS planner exhaustively offload 
computations, and " +
+          "helps transition planner choose columnar-to-columnar transition 
over others.")
       .stringConf
       .createWithDefaultString("legacy")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to