This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 152045891c [CORE][VL] Cost model code refactors (#8541)
152045891c is described below
commit 152045891c2baf94781b29a5831ba20bd1e84586
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Jan 16 13:54:00 2025 +0800
[CORE][VL] Cost model code refactors (#8541)
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 4 +-
.../gluten/backendsapi/clickhouse/CHRuleApi.scala | 4 --
.../gluten/backendsapi/velox/VeloxBackend.scala | 4 +-
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 5 --
.../gluten/execution/MiscOperatorSuite.scala | 1 -
.../enumerated/planner/VeloxRasSuite.scala | 30 +++++-----
.../org/apache/gluten/component/Component.scala | 7 +++
.../GlutenCost.scala} | 15 +----
.../extension/columnar/cost/GlutenCostModel.scala | 66 ++++++++++++++++++++++
.../{enumerated/planner => }/cost/LongCost.scala | 6 +-
.../planner => }/cost/LongCostModel.scala | 15 +++--
.../{enumerated/planner => }/cost/LongCoster.scala | 2 +-
.../planner => }/cost/LongCosterChain.scala | 3 +-
.../columnar/enumerated/EnumeratedTransform.scala | 20 ++++++-
.../enumerated/planner/property/Conv.scala | 4 +-
.../extension/columnar/transition/Convention.scala | 8 +--
.../columnar/transition/FloydWarshallGraph.scala | 50 +++++++---------
.../extension/columnar/transition/Transition.scala | 43 ++++++++++----
.../columnar/transition/TransitionGraph.scala | 37 ++++++------
.../columnar/transition/Transitions.scala | 5 +-
.../gluten/extension/injector/GlutenInjector.scala | 32 +----------
.../transition/FloydWarshallGraphSuite.scala | 4 +-
.../planner => }/cost/LegacyCoster.scala | 2 +-
.../planner => }/cost/RoughCoster.scala | 2 +-
.../columnar/transition/TransitionSuite.scala | 7 +--
.../org/apache/gluten/config/GlutenConfig.scala | 7 ++-
26 files changed, 222 insertions(+), 161 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index aa9e3e553c..9626987593 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -24,6 +24,7 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.execution.WriteFilesExecTransformer
import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster}
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionFunc}
import org.apache.gluten.substrait.rel.LocalFilesNode
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
@@ -54,7 +55,6 @@ class CHBackend extends SubstraitBackend {
override def name(): String = CHConf.BACKEND_NAME
override def buildInfo(): BuildInfo =
BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
- override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def iteratorApi(): IteratorApi = new CHIteratorApi
override def sparkPlanExecApi(): SparkPlanExecApi = new CHSparkPlanExecApi
override def transformerApi(): TransformerApi = new CHTransformerApi
@@ -63,6 +63,8 @@ class CHBackend extends SubstraitBackend {
override def listenerApi(): ListenerApi = new CHListenerApi
override def ruleApi(): RuleApi = new CHRuleApi
override def settings(): BackendSettingsApi = CHBackendSettings
+ override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
+ override def costers(): Seq[LongCoster] = Seq(LegacyCoster)
}
object CHBackend {
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 426c88c907..21ae342a22 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -22,7 +22,6 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
-import
org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster
import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy,
HeuristicTransform}
import org.apache.gluten.extension.columnar.offload.{OffloadExchange,
OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.rewrite._
@@ -143,9 +142,6 @@ object CHRuleApi {
}
private def injectRas(injector: RasInjector): Unit = {
- // Register legacy coster for transition planner.
- injector.injectCoster(_ => LegacyCoster)
-
// CH backend doesn't work with RAS at the moment. Inject a rule that aborts any
// execution calls.
injector.injectPreTransform(
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 519e98c5d4..677d8792c7 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -25,6 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution.WriteFilesExecTransformer
import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.gluten.extension.ValidationResult
+import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster,
RoughCoster}
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionFunc}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.rel.LocalFilesNode
@@ -61,7 +62,6 @@ class VeloxBackend extends SubstraitBackend {
override def name(): String = VeloxBackend.BACKEND_NAME
override def buildInfo(): BuildInfo =
BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, VELOX_REVISION_TIME)
- override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def iteratorApi(): IteratorApi = new VeloxIteratorApi
override def sparkPlanExecApi(): SparkPlanExecApi = new VeloxSparkPlanExecApi
override def transformerApi(): TransformerApi = new VeloxTransformerApi
@@ -70,6 +70,8 @@ class VeloxBackend extends SubstraitBackend {
override def listenerApi(): ListenerApi = new VeloxListenerApi
override def ruleApi(): RuleApi = new VeloxRuleApi
override def settings(): BackendSettingsApi = VeloxBackendSettings
+ override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
+ override def costers(): Seq[LongCoster] = Seq(LegacyCoster, RoughCoster)
}
object VeloxBackend {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index 6c60ab7d53..0cf6ac6713 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -23,7 +23,6 @@ import org.apache.gluten.extension._
import org.apache.gluten.extension.columnar._
import
org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow,
RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast}
import org.apache.gluten.extension.columnar.enumerated.{RasOffload, RemoveSort}
-import
org.apache.gluten.extension.columnar.enumerated.planner.cost.{LegacyCoster,
RoughCoster}
import org.apache.gluten.extension.columnar.heuristic.{ExpandFallbackPolicy,
HeuristicTransform}
import org.apache.gluten.extension.columnar.offload.{OffloadExchange,
OffloadJoin, OffloadOthers}
import org.apache.gluten.extension.columnar.rewrite._
@@ -120,10 +119,6 @@ object VeloxRuleApi {
}
private def injectRas(injector: RasInjector): Unit = {
- // Gluten RAS: Costers.
- injector.injectCoster(_ => LegacyCoster)
- injector.injectCoster(_ => RoughCoster)
-
// Gluten RAS: Pre rules.
injector.injectPreTransform(_ => RemoveTransitions)
injector.injectPreTransform(_ => PushDownInputFileExpression.PreOffload)
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
index 5b6b8a30ed..f0be15f07d 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala
@@ -34,7 +34,6 @@ import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters
class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with
AdaptiveSparkPlanHelper {
-
protected val rootPath: String = getClass.getResource("/").getPath
override protected val resourcePath: String = "/tpch-data-parquet"
override protected val fileFormat: String = "parquet"
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
index e7de629b39..050a881394 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
@@ -17,11 +17,11 @@
package org.apache.gluten.extension.columnar.enumerated.planner
import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel,
LegacyCoster, LongCostModel}
import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
-import
org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel,
LegacyCoster, LongCostModel}
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionReq}
-import org.apache.gluten.ras.{Cost, Ras}
+import org.apache.gluten.ras.Ras
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.path.RasPath
import org.apache.gluten.ras.property.PropertySet
@@ -152,7 +152,7 @@ object VeloxRasSuite {
def newRas(rasRules: Seq[RasRule[SparkPlan]]): Ras[SparkPlan] = {
GlutenOptimization
.builder()
- .costModel(sessionCostModel())
+ .costModel(EnumeratedTransform.asRasCostModel(sessionCostModel()))
.addRules(rasRules)
.create()
.asInstanceOf[Ras[SparkPlan]]
@@ -205,27 +205,27 @@ object VeloxRasSuite {
class UserCostModel1 extends GlutenCostModel {
private val base = legacyCostModel()
- override def costOf(node: SparkPlan): Cost = node match {
+ override def costOf(node: SparkPlan): GlutenCost = node match {
case _: RowUnary => base.makeInfCost()
case other => base.costOf(other)
}
- override def costComparator(): Ordering[Cost] = base.costComparator()
- override def makeInfCost(): Cost = base.makeInfCost()
- override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
- override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
- override def makeZeroCost(): Cost = base.makeZeroCost()
+ override def costComparator(): Ordering[GlutenCost] = base.costComparator()
+ override def makeInfCost(): GlutenCost = base.makeInfCost()
+ override def sum(one: GlutenCost, other: GlutenCost): GlutenCost =
base.sum(one, other)
+ override def diff(one: GlutenCost, other: GlutenCost): GlutenCost =
base.diff(one, other)
+ override def makeZeroCost(): GlutenCost = base.makeZeroCost()
}
class UserCostModel2 extends GlutenCostModel {
private val base = legacyCostModel()
- override def costOf(node: SparkPlan): Cost = node match {
+ override def costOf(node: SparkPlan): GlutenCost = node match {
case _: ColumnarUnary => base.makeInfCost()
case other => base.costOf(other)
}
- override def costComparator(): Ordering[Cost] = base.costComparator()
- override def makeInfCost(): Cost = base.makeInfCost()
- override def sum(one: Cost, other: Cost): Cost = base.sum(one, other)
- override def diff(one: Cost, other: Cost): Cost = base.diff(one, other)
- override def makeZeroCost(): Cost = base.makeZeroCost()
+ override def costComparator(): Ordering[GlutenCost] = base.costComparator()
+ override def makeInfCost(): GlutenCost = base.makeInfCost()
+ override def sum(one: GlutenCost, other: GlutenCost): GlutenCost =
base.sum(one, other)
+ override def diff(one: GlutenCost, other: GlutenCost): GlutenCost =
base.diff(one, other)
+ override def makeZeroCost(): GlutenCost = base.makeZeroCost()
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
index 4a066e1484..bc8640a6bc 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/component/Component.scala
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.component
+import org.apache.gluten.extension.columnar.cost.LongCoster
import org.apache.gluten.extension.columnar.transition.ConventionFunc
import org.apache.gluten.extension.injector.Injector
@@ -69,6 +70,12 @@ trait Component {
*/
def convFuncOverride(): ConventionFunc.Override =
ConventionFunc.Override.Empty
+ /**
+ * A sequence of [[org.apache.gluten.extension.columnar.cost.LongCoster]] Gluten is using for cost
+ * evaluation.
+ */
+ def costers(): Seq[LongCoster] = Nil
+
/** Query planner rules. */
def injectRules(injector: Injector): Unit
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala
similarity index 66%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala
index 41e5529d2e..08a21549a0 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/GlutenCostModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCost.scala
@@ -14,17 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
-import org.apache.gluten.ras.{Cost, CostModel}
-
-import org.apache.spark.sql.execution.SparkPlan
-
-trait GlutenCostModel extends CostModel[SparkPlan] {
- // Returns cost value of one + other.
- def sum(one: Cost, other: Cost): Cost
- // Returns cost value of one - other.
- def diff(one: Cost, other: Cost): Cost
-
- def makeZeroCost(): Cost
-}
+trait GlutenCost
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala
new file mode 100644
index 0000000000..80edf8919f
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/GlutenCostModel.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.cost
+
+import org.apache.gluten.component.Component
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.util.SparkReflectionUtil
+
+/**
+ * The cost model API of Gluten. Used by:
+ * 1. RAS planner for cost-based optimization; 2. Transition graph for choosing transition paths.
+ */
+trait GlutenCostModel {
+ def costOf(node: SparkPlan): GlutenCost
+ def costComparator(): Ordering[GlutenCost]
+ def makeZeroCost(): GlutenCost
+ def makeInfCost(): GlutenCost
+ // Returns cost value of one + other.
+ def sum(one: GlutenCost, other: GlutenCost): GlutenCost
+ // Returns cost value of one - other.
+ def diff(one: GlutenCost, other: GlutenCost): GlutenCost
+}
+
+object GlutenCostModel extends Logging {
+ def find(aliasOrClass: String): GlutenCostModel = {
+ val costModelRegistry = LongCostModel.registry()
+ // Components should override Backend's costers. Hence, reversed registration order is applied.
+ Component
+ .sorted()
+ .reverse
+ .flatMap(_.costers())
+ .foreach(coster => costModelRegistry.register(coster))
+ val costModel = find(costModelRegistry, aliasOrClass)
+ costModel
+ }
+
+ private def find(registry: LongCostModel.Registry, aliasOrClass: String):
GlutenCostModel = {
+ if (LongCostModel.Kind.values().contains(aliasOrClass)) {
+ val kind = LongCostModel.Kind.values()(aliasOrClass)
+ val model = registry.get(kind)
+ return model
+ }
+ val clazz = SparkReflectionUtil.classForName(aliasOrClass)
+ logInfo(s"Using user cost model: $aliasOrClass")
+ val ctor = clazz.getDeclaredConstructor()
+ ctor.setAccessible(true)
+ val model: GlutenCostModel =
ctor.newInstance().asInstanceOf[GlutenCostModel]
+ model
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala
similarity index 84%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala
index aa74f7736f..7de8407ffe 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCost.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCost.scala
@@ -14,8 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
-import org.apache.gluten.ras.Cost
-
-case class LongCost(value: Long) extends Cost
+case class LongCost(value: Long) extends GlutenCost
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
similarity index 88%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
index 0d11541b73..2cdf86e6af 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCostModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCostModel.scala
@@ -14,11 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
import org.apache.gluten.exception.GlutenException
import
org.apache.gluten.extension.columnar.enumerated.planner.plan.GlutenPlanModel.GroupLeafExec
-import org.apache.gluten.ras.Cost
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
@@ -39,15 +38,15 @@ abstract class LongCostModel extends GlutenCostModel {
assert(a >= 0)
assert(b >= 0)
val sum = a + b
- if (sum < a || sum < b) Long.MaxValue else sum
+ if (sum < a || sum < b) infLongCost else sum
}
- override def sum(one: Cost, other: Cost): LongCost = (one, other) match {
+ override def sum(one: GlutenCost, other: GlutenCost): LongCost = (one,
other) match {
case (LongCost(value), LongCost(otherValue)) => LongCost(safeSum(value,
otherValue))
}
// Returns cost value of one - other.
- override def diff(one: Cost, other: Cost): Cost = (one, other) match {
+ override def diff(one: GlutenCost, other: GlutenCost): GlutenCost = (one,
other) match {
case (LongCost(value), LongCost(otherValue)) =>
val d = Math.subtractExact(value, otherValue)
require(d >= zeroLongCost, s"Difference between cost $one and $other
should not be negative")
@@ -62,13 +61,13 @@ abstract class LongCostModel extends GlutenCostModel {
def selfLongCostOf(node: SparkPlan): Long
- override def costComparator(): Ordering[Cost] = Ordering.Long.on {
+ override def costComparator(): Ordering[GlutenCost] = Ordering.Long.on {
case LongCost(value) => value
case _ => throw new IllegalStateException("Unexpected cost type")
}
- override def makeInfCost(): Cost = LongCost(infLongCost)
- override def makeZeroCost(): Cost = LongCost(zeroLongCost)
+ override def makeInfCost(): GlutenCost = LongCost(infLongCost)
+ override def makeZeroCost(): GlutenCost = LongCost(zeroLongCost)
}
object LongCostModel extends Logging {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala
similarity index 95%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala
index f06d1a4db8..8346f8987c 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCoster.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCoster.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
import org.apache.spark.sql.execution.SparkPlan
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala
similarity index 96%
rename from
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala
rename to
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala
index 00980e7712..c7fe616a4d 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LongCosterChain.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/cost/LongCosterChain.scala
@@ -14,7 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
+
import org.apache.gluten.exception.GlutenException
import org.apache.spark.sql.execution.SparkPlan
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 59e829e179..72926407c5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -19,12 +19,13 @@ package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.component.Component
import org.apache.gluten.exception.GlutenException
import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
+import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel}
import
org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
-import
org.apache.gluten.extension.columnar.enumerated.planner.cost.GlutenCostModel
import org.apache.gluten.extension.columnar.enumerated.planner.property.Conv
import org.apache.gluten.extension.injector.Injector
import org.apache.gluten.extension.util.AdaptiveContext
import org.apache.gluten.logging.LogLevelUtil
+import org.apache.gluten.ras.{Cost, CostModel}
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.RasRule
@@ -47,11 +48,12 @@ import org.apache.spark.sql.execution._
case class EnumeratedTransform(costModel: GlutenCostModel, rules:
Seq[RasRule[SparkPlan]])
extends Rule[SparkPlan]
with LogLevelUtil {
+ import EnumeratedTransform._
private val optimization = {
GlutenOptimization
.builder()
- .costModel(costModel)
+ .costModel(asRasCostModel(costModel))
.addRules(rules)
.create()
}
@@ -82,4 +84,18 @@ object EnumeratedTransform {
val call = new ColumnarRuleCall(session, AdaptiveContext(session), false)
dummyInjector.gluten.ras.createEnumeratedTransform(call)
}
+
+ def asRasCostModel(gcm: GlutenCostModel): CostModel[SparkPlan] = {
+ new CostModelAdapter(gcm)
+ }
+
+ /** The adapter to make GlutenCostModel comply with RAS cost model. */
+ private class CostModelAdapter(gcm: GlutenCostModel) extends
CostModel[SparkPlan] {
+ override def costOf(node: SparkPlan): Cost = CostAdapter(gcm.costOf(node))
+ override def costComparator(): Ordering[Cost] =
+ gcm.costComparator().on[Cost] { case CostAdapter(gc) => gc }
+ override def makeInfCost(): Cost = CostAdapter(gcm.makeInfCost())
+ }
+
+ private case class CostAdapter(gc: GlutenCost) extends Cost
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
index 9fa0a839a4..ff530d49bc 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/property/Conv.scala
@@ -46,7 +46,7 @@ sealed trait Conv extends Property[SparkPlan] {
return true
}
val prop = this.asInstanceOf[Prop]
- val out = Transition.factory().satisfies(prop.prop, req.req)
+ val out = Transition.factory.satisfies(prop.prop, req.req)
out
}
}
@@ -64,7 +64,7 @@ object Conv {
def findTransition(from: Conv, to: Conv): Transition = {
val prop = from.asInstanceOf[Prop]
val req = to.asInstanceOf[Req]
- val out = Transition.factory().findTransition(prop.prop, req.req, new
IllegalStateException())
+ val out = Transition.factory.findTransition(prop.prop, req.req, new
IllegalStateException())
out
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index ff0f295852..cd341410cf 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -116,21 +116,21 @@ object Convention {
protected[this] def registerTransitions(): Unit
final protected[this] def fromRow(transition: Transition): Unit = {
- Transition.graph.addEdge(RowType.VanillaRow, this, transition)
+ Transition.factory.update(graph => graph.addEdge(RowType.VanillaRow,
this, transition))
}
final protected[this] def toRow(transition: Transition): Unit = {
- Transition.graph.addEdge(this, RowType.VanillaRow, transition)
+ Transition.factory.update(graph => graph.addEdge(this,
RowType.VanillaRow, transition))
}
final protected[this] def fromBatch(from: BatchType, transition:
Transition): Unit = {
assert(from != this)
- Transition.graph.addEdge(from, this, transition)
+ Transition.factory.update(graph => graph.addEdge(from, this, transition))
}
final protected[this] def toBatch(to: BatchType, transition: Transition):
Unit = {
assert(to != this)
- Transition.graph.addEdge(this, to, transition)
+ Transition.factory.update(graph => graph.addEdge(this, to, transition))
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
index b05e939687..00c687d3b6 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
@@ -44,8 +44,8 @@ object FloydWarshallGraph {
def cost(costModel: CostModel[E]): Cost
}
- def builder[V <: AnyRef, E <: AnyRef](costModelFactory: () => CostModel[E]):
Builder[V, E] = {
- Builder.create(costModelFactory)
+ def builder[V <: AnyRef, E <: AnyRef](): Builder[V, E] = {
+ Builder.create()
}
private object Path {
@@ -83,24 +83,22 @@ object FloydWarshallGraph {
trait Builder[V <: AnyRef, E <: AnyRef] {
def addVertex(v: V): Builder[V, E]
def addEdge(from: V, to: V, edge: E): Builder[V, E]
- def build(): FloydWarshallGraph[V, E]
+ def build(costModel: CostModel[E]): FloydWarshallGraph[V, E]
}
private object Builder {
- // Thread safe.
- private class Impl[V <: AnyRef, E <: AnyRef](costModelFactory: () =>
CostModel[E])
- extends Builder[V, E] {
+ private class Impl[V <: AnyRef, E <: AnyRef]() extends Builder[V, E] {
private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] =
mutable.Map()
private var graph: Option[FloydWarshallGraph[V, E]] = None
- override def addVertex(v: V): Builder[V, E] = synchronized {
+ override def addVertex(v: V): Builder[V, E] = {
assert(!pathTable.contains(v), s"Vertex $v already exists in graph")
pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v,
Path(Nil))
graph = None
this
}
- override def addEdge(from: V, to: V, edge: E): Builder[V, E] =
synchronized {
+ override def addEdge(from: V, to: V, edge: E): Builder[V, E] = {
assert(from != to, s"Input vertices $from and $to should be different")
assert(pathTable.contains(from), s"Vertex $from not exists in graph")
assert(pathTable.contains(to), s"Vertex $to not exists in graph")
@@ -110,26 +108,7 @@ object FloydWarshallGraph {
this
}
- override def build(): FloydWarshallGraph[V, E] = synchronized {
- if (graph.isEmpty) {
- graph = Some(compile())
- }
- return graph.get
- }
-
- private def hasPath(from: V, to: V): Boolean = {
- if (!pathTable.contains(from)) {
- return false
- }
- val vec = pathTable(from)
- if (!vec.contains(to)) {
- return false
- }
- true
- }
-
- private def compile(): FloydWarshallGraph[V, E] = {
- val costModel = costModelFactory()
+ override def build(costModel: CostModel[E]): FloydWarshallGraph[V, E] = {
val vertices = pathTable.keys
for (k <- vertices) {
for (i <- vertices) {
@@ -156,10 +135,21 @@ object FloydWarshallGraph {
}
new FloydWarshallGraph.Impl(pathTable.map { case (k, m) => (k,
m.toMap) }.toMap)
}
+
+ private def hasPath(from: V, to: V): Boolean = {
+ if (!pathTable.contains(from)) {
+ return false
+ }
+ val vec = pathTable(from)
+ if (!vec.contains(to)) {
+ return false
+ }
+ true
+ }
}
- def create[V <: AnyRef, E <: AnyRef](costModelFactory: () =>
CostModel[E]): Builder[V, E] = {
- new Impl(costModelFactory)
+ def create[V <: AnyRef, E <: AnyRef](): Builder[V, E] = {
+ new Impl()
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index e7a073d9ad..41951d6da9 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -16,10 +16,14 @@
*/
package org.apache.gluten.extension.columnar.transition
+import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.extension.columnar.cost.GlutenCostModel
import org.apache.spark.sql.execution.SparkPlan
+import scala.collection.mutable
+
/**
* Transition is a simple function to convert a query plan to interested
[[ConventionReq]].
*
@@ -47,9 +51,7 @@ trait Transition {
object Transition {
val empty: Transition = (plan: SparkPlan) => plan
private val abort: Transition = (_: SparkPlan) => throw new
UnsupportedOperationException("Abort")
- private[transition] val graph: TransitionGraph.Builder =
TransitionGraph.builder()
-
- def factory(): Factory = Factory.newBuiltin(graph.build())
+ val factory = Factory.newBuiltin()
def notFound(plan: SparkPlan): GlutenException = {
new GlutenException(s"No viable transition found from plan's child to
itself: $plan")
@@ -74,16 +76,32 @@ object Transition {
transition.isEmpty
}
- protected def findTransition(from: Convention, to: ConventionReq)(
+ def update(body: TransitionGraph.Builder => Unit): Unit
+
+ protected[Factory] def findTransition(from: Convention, to: ConventionReq)(
orElse: => Transition): Transition
}
private object Factory {
- def newBuiltin(graph: TransitionGraph): Factory = {
- new BuiltinFactory(graph)
+ def newBuiltin(): Factory = {
+ new BuiltinFactory()
}
- private class BuiltinFactory(graph: TransitionGraph) extends Factory {
+ private class BuiltinFactory() extends Factory {
+ private val graphBuilder: TransitionGraph.Builder =
TransitionGraph.builder()
+ // This cache lets users set a new cost model within the same
Spark session;
+ // the new cost model then takes effect for subsequent transition-finding
requests.
+ private val graphCache = mutable.Map[String, TransitionGraph]()
+
+ private def graph(): TransitionGraph = synchronized {
+ val aliasOrClass = GlutenConfig.get.rasCostModel
+ graphCache.getOrElseUpdate(
+ aliasOrClass, {
+ val base = GlutenCostModel.find(aliasOrClass)
+ graphBuilder.build(TransitionGraph.asTransitionCostModel(base))
+ })
+ }
+
override def findTransition(from: Convention, to: ConventionReq)(
orElse: => Transition): Transition = {
assert(
@@ -104,7 +122,7 @@ object Transition {
case Convention.RowType.None =>
// Input query plan doesn't have recognizable row-based output,
// find columnar-to-row transition.
- graph.transitionOfOption(from.batchType,
toRowType).getOrElse(orElse)
+ graph().transitionOfOption(from.batchType,
toRowType).getOrElse(orElse)
case fromRowType if toRowType == fromRowType =>
// We have only one single built-in row type.
Transition.empty
@@ -117,12 +135,12 @@ object Transition {
case Convention.BatchType.None =>
// Input query plan doesn't have recognizable columnar output,
// find row-to-columnar transition.
- graph.transitionOfOption(from.rowType,
toBatchType).getOrElse(orElse)
+ graph().transitionOfOption(from.rowType,
toBatchType).getOrElse(orElse)
case fromBatchType if toBatchType == fromBatchType =>
Transition.empty
case fromBatchType =>
// Find columnar-to-columnar transition.
- graph.transitionOfOption(fromBatchType,
toBatchType).getOrElse(orElse)
+ graph().transitionOfOption(fromBatchType,
toBatchType).getOrElse(orElse)
}
case (ConventionReq.RowType.Any, ConventionReq.BatchType.Any) =>
Transition.empty
@@ -132,6 +150,11 @@ object Transition {
}
out
}
+
+ override def update(func: TransitionGraph.Builder => Unit): Unit =
synchronized {
+ func(graphBuilder)
+ graphCache.clear()
+ }
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
index 8e97443831..7dece0b3f5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
@@ -16,9 +16,8 @@
*/
package org.apache.gluten.extension.columnar.transition
-import org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform
+import org.apache.gluten.extension.columnar.cost.{GlutenCost, GlutenCostModel}
import org.apache.gluten.extension.columnar.transition.Convention.BatchType
-import org.apache.gluten.ras.Cost
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.SparkReflectionUtil
@@ -40,7 +39,7 @@ object TransitionGraph {
}
final private def register(): Unit = BatchType.synchronized {
- Transition.graph.addVertex(this)
+ Transition.factory.update(graph => graph.addVertex(this))
register0()
}
@@ -51,8 +50,13 @@ object TransitionGraph {
type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition]
- def builder(): Builder = {
- FloydWarshallGraph.builder(() => new TransitionCostModel())
+ private[transition] def builder(): Builder = {
+ FloydWarshallGraph.builder()
+ }
+
+ private[transition] def asTransitionCostModel(
+ base: GlutenCostModel): FloydWarshallGraph.CostModel[Transition] = {
+ new TransitionCostModel(base)
}
implicit class TransitionGraphOps(val graph: TransitionGraph) {
@@ -93,21 +97,22 @@ object TransitionGraph {
}
/** Reuse RAS cost to represent transition cost. */
- private case class TransitionCost(value: Cost, nodeNames: Seq[String])
+ private case class TransitionCost(value: GlutenCost, nodeNames: Seq[String])
extends FloydWarshallGraph.Cost
/**
- * The cost model reuses RAS's cost model to evaluate cost of transitions.
+ * The transition cost model relies internally on the registered Gluten cost model
to evaluate the
+ * cost of transitions.
*
* Note the transition graph is built once for all subsequent Spark sessions
created on the same
- * driver, so any access to Spark dynamic SQL config in RAS cost model will
not take effect for
+ * driver, so any access to Spark dynamic SQL config in Gluten cost model
will not take effect for
* the transition cost evaluation. Hence, it's not recommended to access
Spark dynamic
- * configurations in RAS cost model as well.
+ * configurations in Gluten cost model as well.
*/
- private class TransitionCostModel() extends
FloydWarshallGraph.CostModel[Transition] {
- private val rasCostModel = EnumeratedTransform.static().costModel
+ private class TransitionCostModel(base: GlutenCostModel)
+ extends FloydWarshallGraph.CostModel[Transition] {
- override def zero(): TransitionCost =
TransitionCost(rasCostModel.makeZeroCost(), Nil)
+ override def zero(): TransitionCost = TransitionCost(base.makeZeroCost(),
Nil)
override def costOf(transition: Transition): TransitionCost = {
costOf0(transition)
}
@@ -115,13 +120,13 @@ object TransitionGraph {
one: FloydWarshallGraph.Cost,
other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = (one,
other) match {
case (TransitionCost(c1, p1), TransitionCost(c2, p2)) =>
- TransitionCost(rasCostModel.sum(c1, c2), p1 ++ p2)
+ TransitionCost(base.sum(c1, c2), p1 ++ p2)
}
override def costComparator(): Ordering[FloydWarshallGraph.Cost] = {
(x: FloydWarshallGraph.Cost, y: FloydWarshallGraph.Cost) =>
(x, y) match {
case (TransitionCost(v1, nodeNames1), TransitionCost(v2,
nodeNames2)) =>
- val diff = rasCostModel.costComparator().compare(v1, v2)
+ val diff = base.costComparator().compare(v1, v2)
if (diff != 0) {
diff
} else {
@@ -139,14 +144,14 @@ object TransitionGraph {
* The calculation considers C2C's cost as half of C2R / R2C's cost. So
query planner prefers
* C2C than C2R / R2C.
*/
- def rasCostOfPlan(plan: SparkPlan): Cost = rasCostModel.costOf(plan)
+ def rasCostOfPlan(plan: SparkPlan): GlutenCost = base.costOf(plan)
def nodeNamesOfPlan(plan: SparkPlan): Seq[String] = {
plan.map(_.nodeName).reverse
}
val leafCost = rasCostOfPlan(leaf)
val accumulatedCost = rasCostOfPlan(transited)
- val costDiff = rasCostModel.diff(accumulatedCost, leafCost)
+ val costDiff = base.diff(accumulatedCost, leafCost)
val leafNodeNames = nodeNamesOfPlan(leaf)
val accumulatedNodeNames = nodeNamesOfPlan(transited)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index 297485d844..6ac847d19e 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -52,7 +52,7 @@ case class InsertTransitions(convReq: ConventionReq) extends
Rule[SparkPlan] {
child
} else {
val transition =
- Transition.factory().findTransition(from, convReq,
Transition.notFound(node))
+ Transition.factory.findTransition(from, convReq,
Transition.notFound(node))
val newChild = transition.apply(child)
newChild
}
@@ -100,8 +100,7 @@ object Transitions {
def enforceReq(plan: SparkPlan, req: ConventionReq): SparkPlan = {
val convFunc = ConventionFunc.create()
val removed = RemoveTransitions.removeForNode(plan)
- val transition = Transition
- .factory()
+ val transition = Transition.factory
.findTransition(convFunc.conventionOf(removed), req,
Transition.notFound(removed, req))
val out = transition.apply(removed)
out
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
index 23db1c436d..a208db2c96 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/injector/GlutenInjector.scala
@@ -20,8 +20,8 @@ import org.apache.gluten.config.GlutenConfig
import org.apache.gluten.extension.GlutenColumnarRule
import org.apache.gluten.extension.columnar.ColumnarRuleApplier
import
org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
+import org.apache.gluten.extension.columnar.cost.GlutenCostModel
import org.apache.gluten.extension.columnar.enumerated.{EnumeratedApplier,
EnumeratedTransform}
-import
org.apache.gluten.extension.columnar.enumerated.planner.cost.{GlutenCostModel,
LongCoster, LongCostModel}
import org.apache.gluten.extension.columnar.heuristic.{HeuristicApplier,
HeuristicTransform}
import org.apache.gluten.ras.rule.RasRule
@@ -29,7 +29,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.util.SparkReflectionUtil
import scala.collection.mutable
@@ -106,7 +105,6 @@ object GlutenInjector {
class RasInjector extends Logging {
private val preTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall
=> Rule[SparkPlan]]
private val rasRuleBuilders = mutable.Buffer.empty[ColumnarRuleCall =>
RasRule[SparkPlan]]
- private val costerBuilders = mutable.Buffer.empty[ColumnarRuleCall =>
LongCoster]
private val postTransformBuilders = mutable.Buffer.empty[ColumnarRuleCall
=> Rule[SparkPlan]]
def injectPreTransform(builder: ColumnarRuleCall => Rule[SparkPlan]): Unit
= {
@@ -117,10 +115,6 @@ object GlutenInjector {
rasRuleBuilders += builder
}
- def injectCoster(builder: ColumnarRuleCall => LongCoster): Unit = {
- costerBuilders += builder
- }
-
def injectPostTransform(builder: ColumnarRuleCall => Rule[SparkPlan]):
Unit = {
postTransformBuilders += builder
}
@@ -135,31 +129,9 @@ object GlutenInjector {
def createEnumeratedTransform(call: ColumnarRuleCall): EnumeratedTransform
= {
// Build RAS rules.
val rules = rasRuleBuilders.map(_(call))
-
- // Build the cost model.
- val costModelRegistry = LongCostModel.registry()
- costerBuilders.foreach(cb => costModelRegistry.register(cb(call)))
- val aliasOrClass = call.glutenConf.rasCostModel
- val costModel = findCostModel(costModelRegistry, aliasOrClass)
-
+ val costModel = GlutenCostModel.find(call.glutenConf.rasCostModel)
// Create transform.
EnumeratedTransform(costModel, rules.toSeq)
}
-
- private def findCostModel(
- registry: LongCostModel.Registry,
- aliasOrClass: String): GlutenCostModel = {
- if (LongCostModel.Kind.values().contains(aliasOrClass)) {
- val kind = LongCostModel.Kind.values()(aliasOrClass)
- val model = registry.get(kind)
- return model
- }
- val clazz = SparkReflectionUtil.classForName(aliasOrClass)
- logInfo(s"Using user cost model: $aliasOrClass")
- val ctor = clazz.getDeclaredConstructor()
- ctor.setAccessible(true)
- val model: GlutenCostModel =
ctor.newInstance().asInstanceOf[GlutenCostModel]
- model
- }
}
}
diff --git
a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
index 7b60940a1a..7d78df45c1 100644
---
a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
+++
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
@@ -36,7 +36,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite {
val e42 = Edge(3)
val graph = FloydWarshallGraph
- .builder(() => CostModel)
+ .builder()
.addVertex(v0)
.addVertex(v1)
.addVertex(v2)
@@ -47,7 +47,7 @@ class FloydWarshallGraphSuite extends AnyFunSuite {
.addEdge(v0, v3, e03)
.addEdge(v3, v4, e34)
.addEdge(v4, v2, e42)
- .build()
+ .build(CostModel)
assert(graph.hasPath(v0, v1))
assert(graph.hasPath(v0, v2))
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala
similarity index 96%
rename from
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala
rename to
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala
index bb89d0035b..a8e1524fc9 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/LegacyCoster.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/LegacyCoster.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
import
org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike,
ColumnarToRowLike, RowToColumnarLike}
import org.apache.gluten.utils.PlanUtil
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala
similarity index 97%
rename from
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala
rename to
gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala
index ab893265ec..caee696df6 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/enumerated/planner/cost/RoughCoster.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/cost/RoughCoster.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.extension.columnar.enumerated.planner.cost
+package org.apache.gluten.extension.columnar.cost
import org.apache.gluten.execution.RowToColumnarExecBase
import
org.apache.gluten.extension.columnar.transition.{ColumnarToColumnarLike,
ColumnarToRowLike, RowToColumnarLike}
diff --git
a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
index 2c423783fd..03c9c0f559 100644
---
a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
+++
b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.backend.Backend
import org.apache.gluten.component.Component
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.execution.{ColumnarToColumnarExec, GlutenPlan}
-import
org.apache.gluten.extension.columnar.enumerated.planner.cost.LegacyCoster
+import org.apache.gluten.extension.columnar.cost.{LegacyCoster, LongCoster}
import org.apache.gluten.extension.injector.Injector
import org.apache.spark.rdd.RDD
@@ -152,8 +152,7 @@ object TransitionSuite extends TransitionSuiteBase {
override def name(): String = "dummy-backend"
override def buildInfo(): Component.BuildInfo =
Component.BuildInfo("DUMMY_BACKEND", "N/A", "N/A", "N/A")
- override def injectRules(injector: Injector): Unit = {
- injector.gluten.ras.injectCoster(_ => LegacyCoster)
- }
+ override def injectRules(injector: Injector): Unit = {}
+ override def costers(): Seq[LongCoster] = Seq(LegacyCoster)
}
}
diff --git
a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
index 3f30d00430..e4d3a76326 100644
--- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala
@@ -1431,12 +1431,15 @@ object GlutenConfig {
.booleanConf
.createWithDefault(false)
+ // FIXME: This option is no longer only used by RAS. Should change key to
+ // `spark.gluten.costModel` or something similar.
val RAS_COST_MODEL =
buildConf("spark.gluten.ras.costModel")
.doc(
"The class name of user-defined cost model that will be used by
Gluten's transition " +
- "planner as well as by RAS. If not specified, a legacy built-in cost
model that " +
- "exhaustively offloads computations will be used.")
+ "planner as well as by RAS. If not specified, a legacy built-in cost
model will be " +
+ "used. The legacy cost model helps RAS planner exhaustively offload
computations, and " +
+ "helps transition planner choose columnar-to-columnar transition
over others.")
.stringConf
.createWithDefaultString("legacy")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]