This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 255b0cc8eb [GLUTEN-7313][VL] Explicit Arrow transitions, part 2: new
algorithm to find optimal transition (#7372)
255b0cc8eb is described below
commit 255b0cc8ebae7b23e2966dfd922b591958b8b6fa
Author: Hongze Zhang <[email protected]>
AuthorDate: Sun Sep 29 09:10:55 2024 +0800
[GLUTEN-7313][VL] Explicit Arrow transitions, part 2: new algorithm to find
optimal transition (#7372)
---
.../gluten/backendsapi/clickhouse/CHBackend.scala | 2 +-
.../backendsapi/clickhouse/CHListenerApi.scala | 3 +
.../org/apache/gluten/columnarbatch/CHBatch.scala | 5 +-
.../gluten/backendsapi/velox/VeloxBackend.scala | 24 ++--
.../backendsapi/velox/VeloxListenerApi.scala | 8 ++
.../api/python/ColumnarArrowEvalPythonExec.scala | 21 ++-
.../columnar/transition/VeloxTransitionSuite.scala | 44 +++---
.../scala/org/apache/gluten/backend/Backend.scala | 9 +-
.../extension/columnar/transition/Convention.scala | 24 ++--
.../columnar/transition/ConventionFunc.scala | 56 +++++---
.../columnar/transition/ConventionReq.scala | 2 +-
.../columnar/transition/FloydWarshallGraph.scala | 158 +++++++++++++++++++++
.../extension/columnar/transition/Transition.scala | 119 ++++------------
.../columnar/transition/TransitionGraph.scala | 90 ++++++++++++
.../columnar/transition/Transitions.scala | 2 +-
.../extension/columnar/transition/package.scala | 1 +
.../apache/spark/util/SparkReflectionUtil.scala | 23 +++
.../transition/FloydWarshallGraphSuite.scala | 103 ++++++++++++++
.../org/apache/gluten/extension/GlutenPlan.scala | 2 +-
.../scala/org/apache/gluten/utils/PlanUtil.scala | 4 +-
.../ColumnarCollapseTransformStages.scala | 2 +-
21 files changed, 521 insertions(+), 181 deletions(-)
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index d8966519da..900c9bafb4 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -49,7 +49,7 @@ import scala.util.control.Breaks.{break, breakable}
class CHBackend extends SubstraitBackend {
override def name(): String = CHConf.BACKEND_NAME
- override def batchType: Convention.BatchType = CHBatch
+ override def defaultBatchType: Convention.BatchType = CHBatch
override def buildInfo(): Backend.BuildInfo =
Backend.BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
override def iteratorApi(): IteratorApi = new CHIteratorApi
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
index f326f2edcf..16f5fa064c 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHListenerApi.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.backendsapi.clickhouse
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.ListenerApi
+import org.apache.gluten.columnarbatch.CHBatch
import org.apache.gluten.execution.CHBroadcastBuildSideCache
import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects,
GlutenParquetWriterInjects, GlutenRowSplitter}
import org.apache.gluten.expression.UDFMappings
@@ -68,6 +69,8 @@ class CHListenerApi extends ListenerApi with Logging {
override def onExecutorShutdown(): Unit = shutdown()
private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
+ // Force batch type initializations.
+ CHBatch.getClass
SparkDirectoryUtil.init(conf)
val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY)
if (StringUtils.isBlank(libPath)) {
diff --git
a/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
b/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
index 0121d01578..079fa6bfd7 100644
---
a/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
+++
b/backends-clickhouse/src/main/scala/org/apache/gluten/columnarbatch/CHBatch.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.execution.{CHColumnarToRowExec,
RowToCHNativeColumna
/**
* ClickHouse batch convention.
*
- * [[fromRow]] and [[toRow]] need a [[TransitionDef]] instance. The scala
allows an compact way to
- * implement trait using a lambda function.
+ * [[fromRow]] and [[toRow]] need a
+ * [[org.apache.gluten.extension.columnar.transition.TransitionDef]] instance.
Scala allows a
+ compact way to implement a trait using a lambda function.
*
* Here the detail definition is given in [[CHBatch.fromRow]].
* {{{
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
index 027876a2b0..56dc92f420 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala
@@ -25,8 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution.WriteFilesExecTransformer
import org.apache.gluten.expression.WindowFunctionsBuilder
import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.extension.columnar.transition.Convention
-import
org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
+import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionFunc}
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
import
org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat,
OrcReadFormat, ParquetReadFormat}
@@ -37,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias,
CumeDist, DenseRank, De
import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
ApproximatePercentile}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
+import org.apache.spark.sql.execution.{ColumnarCachedBatchSerializer,
SparkPlan}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import
org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
import org.apache.spark.sql.execution.datasources.{FileFormat,
InsertIntoHadoopFsRelationCommand}
@@ -51,14 +50,10 @@ import org.apache.hadoop.fs.Path
import scala.util.control.Breaks.breakable
class VeloxBackend extends SubstraitBackend {
+ import VeloxBackend._
override def name(): String = VeloxBackend.BACKEND_NAME
- override def batchType: Convention.BatchType = VeloxBatch
- override def batchTypeFunc(): BatchOverride = {
- case i: InMemoryTableScanExec
- if i.supportsColumnar && i.relation.cacheBuilder.serializer
- .isInstanceOf[ColumnarCachedBatchSerializer] =>
- VeloxBatch
- }
+ override def defaultBatchType: Convention.BatchType = VeloxBatch
+ override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
override def buildInfo(): Backend.BuildInfo =
Backend.BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION,
VELOX_REVISION_TIME)
override def iteratorApi(): IteratorApi = new VeloxIteratorApi
@@ -74,6 +69,15 @@ class VeloxBackend extends SubstraitBackend {
object VeloxBackend {
val BACKEND_NAME: String = "velox"
val CONF_PREFIX: String = GlutenConfig.prefixOf(BACKEND_NAME)
+
+ private class ConvFunc() extends ConventionFunc.Override {
+ override def batchTypeOf: PartialFunction[SparkPlan, Convention.BatchType]
= {
+ case i: InMemoryTableScanExec
+ if i.supportsColumnar && i.relation.cacheBuilder.serializer
+ .isInstanceOf[ColumnarCachedBatchSerializer] =>
+ VeloxBatch
+ }
+ }
}
object VeloxBackendSettings extends BackendSettingsApi {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
index 926991ada7..e763e31dc5 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala
@@ -18,6 +18,8 @@ package org.apache.gluten.backendsapi.velox
import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.ListenerApi
+import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch,
ArrowNativeBatch}
+import org.apache.gluten.columnarbatch.VeloxBatch
import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects,
GlutenParquetWriterInjects, GlutenRowSplitter}
import org.apache.gluten.expression.UDFMappings
import org.apache.gluten.init.NativeBackendInitializer
@@ -119,6 +121,12 @@ class VeloxListenerApi extends ListenerApi with Logging {
override def onExecutorShutdown(): Unit = shutdown()
private def initialize(conf: SparkConf): Unit = {
+ // Force batch type initializations.
+ VeloxBatch.getClass
+ ArrowJavaBatch.getClass
+ ArrowNativeBatch.getClass
+
+ // Sets this configuration only once, since it is not undoable.
if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE,
defaultValue = false)) {
val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR)
JniWorkspace.enableDebug(debugDir)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
index 5c27f94ca8..f1f5eb9062 100644
---
a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala
@@ -16,8 +16,8 @@
*/
package org.apache.spark.api.python
-import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxBatch}
import org.apache.gluten.columnarbatch.ArrowBatches.ArrowJavaBatch
+import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.exception.GlutenException
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.transition.{Convention,
ConventionReq}
@@ -218,13 +218,8 @@ case class ColumnarArrowEvalPythonExec(
override protected def batchType0(): Convention.BatchType = ArrowJavaBatch
- // FIXME: Make this accepts ArrowJavaBatch as input. Before doing that, a
weight-based
- // shortest patch algorithm should be added into transition factory. So
that the factory
- // can find out row->velox->arrow-native->arrow-java as the possible viable
transition.
- // Otherwise with current solution, any input (even already in Arrow Java
format) will be
- // converted into Velox format then into Arrow Java format before entering
python runner.
override def requiredChildrenConventions(): Seq[ConventionReq] = List(
- ConventionReq.of(ConventionReq.RowType.Any,
ConventionReq.BatchType.Is(VeloxBatch)))
+ ConventionReq.of(ConventionReq.RowType.Any,
ConventionReq.BatchType.Is(ArrowJavaBatch)))
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"),
@@ -348,17 +343,17 @@ case class ColumnarArrowEvalPythonExec(
val inputBatchIter = contextAwareIterator.map {
inputCb =>
start_time = System.nanoTime()
- val loaded =
ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), inputCb)
- ColumnarBatches.retain(loaded)
+ ColumnarBatches.checkLoaded(inputCb)
+ ColumnarBatches.retain(inputCb)
// 0. cache input for later merge
- inputCbCache += loaded
- numInputRows += loaded.numRows
+ inputCbCache += inputCb
+ numInputRows += inputCb.numRows
// We only need to pass the referred cols data to python worker
for evaluation.
var colsForEval = new ArrayBuffer[ColumnVector]()
for (i <- originalOffsets) {
- colsForEval += loaded.column(i)
+ colsForEval += inputCb.column(i)
}
- new ColumnarBatch(colsForEval.toArray, loaded.numRows())
+ new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
}
val outputColumnarBatchIterator =
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
index df8ee8cc53..e23e7885f2 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
@@ -19,7 +19,6 @@ package org.apache.gluten.extension.columnar.transition
import org.apache.gluten.backendsapi.velox.VeloxListenerApi
import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch,
ArrowNativeBatch}
import org.apache.gluten.columnarbatch.VeloxBatch
-import org.apache.gluten.exception.GlutenException
import org.apache.gluten.execution.{LoadArrowDataExec, OffloadArrowDataExec,
RowToVeloxColumnarExec, VeloxColumnarToRowExec}
import
org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch
import org.apache.gluten.test.MockVeloxBackend
@@ -64,11 +63,10 @@ class VeloxTransitionSuite extends SharedSparkSession {
test("ArrowNative R2C - requires Arrow input") {
val in = BatchUnary(ArrowNativeBatch, RowLeaf())
- assertThrows[GlutenException] {
- // No viable transitions.
- // FIXME: Support this case.
- Transitions.insertTransitions(in, outputsColumnar = false)
- }
+ val out = Transitions.insertTransitions(in, outputsColumnar = false)
+ assert(
+ out == ColumnarToRowExec(
+ LoadArrowDataExec(BatchUnary(ArrowNativeBatch,
RowToVeloxColumnarExec(RowLeaf())))))
}
test("ArrowNative-to-Velox C2C") {
@@ -92,11 +90,12 @@ class VeloxTransitionSuite extends SharedSparkSession {
test("Vanilla-to-ArrowNative C2C") {
val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VanillaBatch))
- assertThrows[GlutenException] {
- // No viable transitions.
- // FIXME: Support this case.
- Transitions.insertTransitions(in, outputsColumnar = false)
- }
+ val out = Transitions.insertTransitions(in, outputsColumnar = false)
+ assert(
+ out == ColumnarToRowExec(
+ LoadArrowDataExec(BatchUnary(
+ ArrowNativeBatch,
+
RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
}
test("ArrowNative-to-Vanilla C2C") {
@@ -121,11 +120,10 @@ class VeloxTransitionSuite extends SharedSparkSession {
test("ArrowJava R2C - requires Arrow input") {
val in = BatchUnary(ArrowJavaBatch, RowLeaf())
- assertThrows[GlutenException] {
- // No viable transitions.
- // FIXME: Support this case.
- Transitions.insertTransitions(in, outputsColumnar = false)
- }
+ val out = Transitions.insertTransitions(in, outputsColumnar = false)
+ assert(
+ out == ColumnarToRowExec(
+ BatchUnary(ArrowJavaBatch,
LoadArrowDataExec(RowToVeloxColumnarExec(RowLeaf())))))
}
test("ArrowJava-to-Velox C2C") {
@@ -146,11 +144,12 @@ class VeloxTransitionSuite extends SharedSparkSession {
test("Vanilla-to-ArrowJava C2C") {
val in = BatchUnary(ArrowJavaBatch, BatchLeaf(VanillaBatch))
- assertThrows[GlutenException] {
- // No viable transitions.
- // FIXME: Support this case.
- Transitions.insertTransitions(in, outputsColumnar = false)
- }
+ val out = Transitions.insertTransitions(in, outputsColumnar = false)
+ assert(
+ out == ColumnarToRowExec(
+ BatchUnary(
+ ArrowJavaBatch,
+
LoadArrowDataExec(RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
}
test("ArrowJava-to-Vanilla C2C") {
@@ -195,8 +194,7 @@ class VeloxTransitionSuite extends SharedSparkSession {
val in = BatchUnary(VanillaBatch, BatchLeaf(VeloxBatch))
val out = Transitions.insertTransitions(in, outputsColumnar = false)
assert(
- out == ColumnarToRowExec(
- BatchUnary(VanillaBatch,
RowToColumnarExec(VeloxColumnarToRowExec(BatchLeaf(VeloxBatch))))))
+ out == ColumnarToRowExec(BatchUnary(VanillaBatch,
LoadArrowDataExec(BatchLeaf(VeloxBatch)))))
}
override protected def beforeAll(): Unit = {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
b/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
index 5f82a2ee7d..1b9175d6b0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backend/Backend.scala
@@ -39,15 +39,16 @@ trait Backend {
def onExecutorStart(pc: PluginContext): Unit = {}
def onExecutorShutdown(): Unit = {}
- /** The columnar-batch type this backend is using. */
- def batchType: Convention.BatchType
+ /** The columnar-batch type this backend is by default using. */
+ def defaultBatchType: Convention.BatchType
/**
* Overrides
[[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is
using to
* determine the convention (its row-based processing / columnar-batch
processing support) of a
- * plan with a user-defined function that accepts a plan then returns batch
type it outputs.
+ * plan with a user-defined function that accepts a plan then returns
convention type it outputs,
+ * and input conventions it requires.
*/
- def batchTypeFunc(): ConventionFunc.BatchOverride = PartialFunction.empty
+ def convFuncOverride(): ConventionFunc.Override =
ConventionFunc.Override.Empty
/** Query planner rules. */
def injectRules(injector: RuleInjector): Unit
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index fcd34cb1f2..d585af7c71 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -61,7 +61,9 @@ object Convention {
Impl(rowType, batchType)
}
- sealed trait RowType
+ sealed trait RowType extends TransitionGraph.Vertex with Serializable {
+ Transition.graph.addVertex(this)
+ }
object RowType {
// None indicates that the plan doesn't support row-based processing.
@@ -69,23 +71,25 @@ object Convention {
final case object VanillaRow extends RowType
}
- trait BatchType extends Serializable {
- final def fromRow(transitionDef: TransitionDef): Unit = {
- Transition.factory.update().defineFromRowTransition(this, transitionDef)
+ trait BatchType extends TransitionGraph.Vertex with Serializable {
+ Transition.graph.addVertex(this)
+
+ final protected def fromRow(transitionDef: TransitionDef): Unit = {
+ Transition.graph.addEdge(RowType.VanillaRow, this,
transitionDef.create())
}
- final def toRow(transitionDef: TransitionDef): Unit = {
- Transition.factory.update().defineToRowTransition(this, transitionDef)
+ final protected def toRow(transitionDef: TransitionDef): Unit = {
+ Transition.graph.addEdge(this, RowType.VanillaRow,
transitionDef.create())
}
- final def fromBatch(from: BatchType, transitionDef: TransitionDef): Unit =
{
+ final protected def fromBatch(from: BatchType, transitionDef:
TransitionDef): Unit = {
assert(from != this)
- Transition.factory.update().defineBatchTransition(from, this,
transitionDef)
+ Transition.graph.addEdge(from, this, transitionDef.create())
}
- final def toBatch(to: BatchType, transitionDef: TransitionDef): Unit = {
+ final protected def toBatch(to: BatchType, transitionDef: TransitionDef):
Unit = {
assert(to != this)
- Transition.factory.update().defineBatchTransition(this, to,
transitionDef)
+ Transition.graph.addEdge(this, to, transitionDef.create())
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
index 21662f503e..c3feefe943 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
@@ -26,13 +26,21 @@ import
org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
/** ConventionFunc is a utility to derive [[Convention]] or [[ConventionReq]]
from a query plan. */
-trait ConventionFunc {
+sealed trait ConventionFunc {
def conventionOf(plan: SparkPlan): Convention
def conventionReqOf(plan: SparkPlan): ConventionReq
}
object ConventionFunc {
- type BatchOverride = PartialFunction[SparkPlan, Convention.BatchType]
+ trait Override {
+ def rowTypeOf: PartialFunction[SparkPlan, Convention.RowType] =
PartialFunction.empty
+ def batchTypeOf: PartialFunction[SparkPlan, Convention.BatchType] =
PartialFunction.empty
+ def conventionReqOf: PartialFunction[SparkPlan, ConventionReq] =
PartialFunction.empty
+ }
+
+ object Override {
+ object Empty extends Override
+ }
// For testing, to make things work without a backend loaded.
private var ignoreBackend: Boolean = false
@@ -53,17 +61,17 @@ object ConventionFunc {
new BuiltinFunc(batchOverride)
}
- private def newOverride(): BatchOverride = {
+ private def newOverride(): Override = {
synchronized {
if (ignoreBackend) {
// For testing
- return PartialFunction.empty
+ return Override.Empty
}
}
- Backend.get().batchTypeFunc()
+ Backend.get().convFuncOverride()
}
- private class BuiltinFunc(o: BatchOverride) extends ConventionFunc {
+ private class BuiltinFunc(o: Override) extends ConventionFunc {
import BuiltinFunc._
override def conventionOf(plan: SparkPlan): Convention = {
val conv = conventionOf0(plan)
@@ -86,7 +94,7 @@ object ConventionFunc {
val batchType = if (a.supportsColumnar) {
// By default, we execute columnar AQE with backend batch output.
// See
org.apache.gluten.extension.columnar.transition.InsertTransitions.apply
- Backend.get().batchType
+ Backend.get().defaultBatchType
} else {
Convention.BatchType.None
}
@@ -98,6 +106,11 @@ object ConventionFunc {
}
private def rowTypeOf(plan: SparkPlan): Convention.RowType = {
+ val out = o.rowTypeOf.applyOrElse(plan, rowTypeOf0)
+ out
+ }
+
+ private def rowTypeOf0(plan: SparkPlan): Convention.RowType = {
val out = plan match {
case k: Convention.KnownRowType =>
k.rowType()
@@ -113,25 +126,26 @@ object ConventionFunc {
}
private def batchTypeOf(plan: SparkPlan): Convention.BatchType = {
- val out = o.applyOrElse(
- plan,
- (p: SparkPlan) =>
- p match {
- case k: Convention.KnownBatchType =>
- k.batchType()
- case _ if plan.supportsColumnar =>
- Convention.BatchType.VanillaBatch
- case _ =>
- Convention.BatchType.None
- }
- )
+ val out = o.batchTypeOf.applyOrElse(plan, batchTypeOf0)
+ out
+ }
+
+ private def batchTypeOf0(plan: SparkPlan): Convention.BatchType = {
+ val out = plan match {
+ case k: Convention.KnownBatchType =>
+ k.batchType()
+ case _ if plan.supportsColumnar =>
+ Convention.BatchType.VanillaBatch
+ case _ =>
+ Convention.BatchType.None
+ }
assert(out == Convention.BatchType.None || plan.supportsColumnar)
out
}
override def conventionReqOf(plan: SparkPlan): ConventionReq = {
- val out = conventionReqOf0(plan)
- out
+ val req = o.conventionReqOf.applyOrElse(plan, conventionReqOf0)
+ req
}
private def conventionReqOf0(plan: SparkPlan): ConventionReq = plan match {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
index cb76ec4de0..ce613bf7db 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
@@ -58,7 +58,7 @@ object ConventionReq {
val vanillaBatch: ConventionReq =
Impl(RowType.Any, BatchType.Is(Convention.BatchType.VanillaBatch))
lazy val backendBatch: ConventionReq =
- Impl(RowType.Any, BatchType.Is(Backend.get().batchType))
+ Impl(RowType.Any, BatchType.Is(Backend.get().defaultBatchType))
def get(plan: SparkPlan): ConventionReq =
ConventionFunc.create().conventionReqOf(plan)
def of(rowType: RowType, batchType: BatchType): ConventionReq =
Impl(rowType, batchType)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
new file mode 100644
index 0000000000..2a4e1f4225
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraph.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.transition
+
+import scala.collection.mutable
+
+/**
+ * Floyd-Warshall algorithm for finding e.g., cheapest transition between
query plan nodes.
+ *
+ * https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm
+ */
+trait FloydWarshallGraph[V <: AnyRef, E <: AnyRef] {
+ import FloydWarshallGraph._
+ def hasPath(from: V, to: V): Boolean
+ def pathOf(from: V, to: V): Path[E]
+}
+
+object FloydWarshallGraph {
+ trait Cost {
+ def +(other: Cost): Cost
+ }
+
+ trait CostModel[E <: AnyRef] {
+ def zero(): Cost
+ def costOf(edge: E): Cost
+ def costComparator(): Ordering[Cost]
+ }
+
+ trait Path[E <: AnyRef] {
+ def edges(): Seq[E]
+ def cost(): Cost
+ }
+
+ def builder[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]): Builder[V,
E] = {
+ Builder.create(costModel)
+ }
+
+ private object Path {
+ def apply[E <: AnyRef](costModel: CostModel[E], edges: Seq[E]): Path[E] =
Impl(edges)(costModel)
+ private case class Impl[E <: AnyRef](override val edges:
Seq[E])(costModel: CostModel[E])
+ extends Path[E] {
+ override val cost: Cost = {
+ edges.map(costModel.costOf).reduceOption(_ +
_).getOrElse(costModel.zero())
+ }
+ }
+ }
+
+ private class Impl[V <: AnyRef, E <: AnyRef](pathTable: Map[V, Map[V,
Path[E]]])
+ extends FloydWarshallGraph[V, E] {
+ override def hasPath(from: V, to: V): Boolean = {
+ if (!pathTable.contains(from)) {
+ return false
+ }
+ val vec = pathTable(from)
+ if (!vec.contains(to)) {
+ return false
+ }
+ true
+ }
+
+ override def pathOf(from: V, to: V): Path[E] = {
+ assert(hasPath(from, to))
+ val path = pathTable(from)(to)
+ path
+ }
+ }
+
+ trait Builder[V <: AnyRef, E <: AnyRef] {
+ def addVertex(v: V): Builder[V, E]
+ def addEdge(from: V, to: V, edge: E): Builder[V, E]
+ def build(): FloydWarshallGraph[V, E]
+ }
+
+ private object Builder {
+ // Thread safe.
+ private class Impl[V <: AnyRef, E <: AnyRef](costModel: CostModel[E])
extends Builder[V, E] {
+ private val pathTable: mutable.Map[V, mutable.Map[V, Path[E]]] =
mutable.Map()
+ private var graph: Option[FloydWarshallGraph[V, E]] = None
+
+ override def addVertex(v: V): Builder[V, E] = synchronized {
+ assert(!pathTable.contains(v), s"Vertex $v already exists in graph")
+ pathTable.getOrElseUpdate(v, mutable.Map()).getOrElseUpdate(v,
Path(costModel, Nil))
+ graph = None
+ this
+ }
+
+ override def addEdge(from: V, to: V, edge: E): Builder[V, E] =
synchronized {
+ assert(from != to, s"Input vertices $from and $to should be different")
+ assert(pathTable.contains(from), s"Vertex $from does not exist in graph")
+ assert(pathTable.contains(to), s"Vertex $to does not exist in graph")
+ assert(!hasPath(from, to), s"Path from $from to $to already exists in
graph")
+ pathTable(from) += to -> Path(costModel, Seq(edge))
+ graph = None
+ this
+ }
+
+ override def build(): FloydWarshallGraph[V, E] = synchronized {
+ if (graph.isEmpty) {
+ graph = Some(compile())
+ }
+ return graph.get
+ }
+
+ private def hasPath(from: V, to: V): Boolean = {
+ if (!pathTable.contains(from)) {
+ return false
+ }
+ val vec = pathTable(from)
+ if (!vec.contains(to)) {
+ return false
+ }
+ true
+ }
+
+ private def compile(): FloydWarshallGraph[V, E] = {
+ val vertices = pathTable.keys
+ for (k <- vertices) {
+ for (i <- vertices) {
+ for (j <- vertices) {
+ if (hasPath(i, k) && hasPath(k, j)) {
+ val pathIk = pathTable(i)(k)
+ val pathKj = pathTable(k)(j)
+ val newPath = Path(costModel, pathIk.edges() ++ pathKj.edges())
+ if (!hasPath(i, j)) {
+ pathTable(i) += j -> newPath
+ } else {
+ val path = pathTable(i)(j)
+ if (costModel.costComparator().compare(newPath.cost(),
path.cost()) < 0) {
+ pathTable(i) += j -> newPath
+ }
+ }
+ }
+ }
+ }
+ }
+ new FloydWarshallGraph.Impl(pathTable.map { case (k, m) => (k,
m.toMap) }.toMap)
+ }
+ }
+
+ def create[V <: AnyRef, E <: AnyRef](costModel: CostModel[E]): Builder[V,
E] = {
+ new Impl(costModel)
+ }
+ }
+}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 3fd2839b5a..87e7204c14 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -18,9 +18,10 @@ package org.apache.gluten.extension.columnar.transition
import org.apache.gluten.exception.GlutenException
-import org.apache.spark.sql.execution.SparkPlan
-
-import scala.collection.mutable
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
/**
* Transition is a simple function to convert a query plan to interested
[[ConventionReq]].
@@ -29,18 +30,18 @@ import scala.collection.mutable
* [[org.apache.gluten.extension.columnar.transition.Convention.BatchType]]'s
definition.
*/
trait Transition {
+ import Transition._
final def apply(plan: SparkPlan): SparkPlan = {
val out = apply0(plan)
- if (out eq plan) {
- assert(
- this == Transition.empty,
- "TransitionDef.empty / Transition.empty should be used when defining
an empty transition.")
- }
out
}
- final def isEmpty: Boolean = {
- this == Transition.empty
+ final lazy val isEmpty: Boolean = {
+ // Tests if a transition is actually no-op.
+ val plan = DummySparkPlan()
+ val out = apply0(plan)
+ val identical = out eq plan
+ identical
}
protected def apply0(plan: SparkPlan): SparkPlan
@@ -56,7 +57,10 @@ object TransitionDef {
object Transition {
val empty: Transition = (plan: SparkPlan) => plan
- val factory: Factory = Factory.newBuiltin()
+ private val abort: Transition = (_: SparkPlan) => throw new
UnsupportedOperationException("Abort")
+ private[transition] val graph: TransitionGraph.Builder =
TransitionGraph.builder()
+
+ def factory(): Factory = Factory.newBuiltin(graph.build())
def notFound(plan: SparkPlan): GlutenException = {
new GlutenException(s"No viable transition found from plan's child to
itself: $plan")
@@ -66,19 +70,6 @@ object Transition {
new GlutenException(s"No viable transition to [$required] found for plan:
$plan")
}
- private class ChainedTransition(first: Transition, second: Transition)
extends Transition {
- override def apply0(plan: SparkPlan): SparkPlan = {
- second(first(plan))
- }
- }
-
- private def chain(first: Transition, second: Transition): Transition = {
- if (first.isEmpty && second.isEmpty) {
- return Transition.empty
- }
- new ChainedTransition(first, second)
- }
-
trait Factory {
final def findTransition(
from: Convention,
@@ -90,63 +81,20 @@ object Transition {
}
final def satisfies(conv: Convention, req: ConventionReq): Boolean = {
- val none = new Transition {
- override protected def apply0(plan: SparkPlan): SparkPlan =
- throw new UnsupportedOperationException()
- }
- val transition = findTransition(conv, req)(none)
+ val transition = findTransition(conv, req)(abort)
transition.isEmpty
}
protected def findTransition(from: Convention, to: ConventionReq)(
orElse: => Transition): Transition
- private[transition] def update(): MutableFactory
- }
-
- trait MutableFactory extends Factory {
- def defineFromRowTransition(to: Convention.BatchType, transitionDef:
TransitionDef): Unit
- def defineToRowTransition(from: Convention.BatchType, transitionDef:
TransitionDef): Unit
- def defineBatchTransition(
- from: Convention.BatchType,
- to: Convention.BatchType,
- transitionDef: TransitionDef): Unit
}
private object Factory {
- def newBuiltin(): Factory = {
- new BuiltinFactory
+ def newBuiltin(graph: TransitionGraph): Factory = {
+ new BuiltinFactory(graph)
}
- private class BuiltinFactory extends MutableFactory {
- private val fromRowTransitions: mutable.Map[Convention.BatchType,
TransitionDef] =
- mutable.Map()
- private val toRowTransitions: mutable.Map[Convention.BatchType,
TransitionDef] = mutable.Map()
- private val batchTransitions
- : mutable.Map[(Convention.BatchType, Convention.BatchType),
TransitionDef] =
- mutable.Map()
-
- override def defineFromRowTransition(
- to: Convention.BatchType,
- transitionDef: TransitionDef): Unit = {
- assert(!fromRowTransitions.contains(to))
- fromRowTransitions += to -> transitionDef
- }
-
- override def defineToRowTransition(
- from: Convention.BatchType,
- transitionDef: TransitionDef): Unit = {
- assert(!toRowTransitions.contains(from))
- toRowTransitions += from -> transitionDef
- }
-
- override def defineBatchTransition(
- from: Convention.BatchType,
- to: Convention.BatchType,
- transitionDef: TransitionDef): Unit = {
- assert(!batchTransitions.contains((from, to)))
- batchTransitions += (from, to) -> transitionDef
- }
-
+ private class BuiltinFactory(graph: TransitionGraph) extends Factory {
override def findTransition(from: Convention, to: ConventionReq)(
orElse: => Transition): Transition = {
assert(
@@ -165,7 +113,7 @@ object Transition {
case (ConventionReq.RowType.Is(toRowType),
ConventionReq.BatchType.Any) =>
from.rowType match {
case Convention.RowType.None =>
-
toRowTransitions.get(from.batchType).map(_.create()).getOrElse(orElse)
+ graph.transitionOfOption(from.batchType,
toRowType).getOrElse(orElse)
case fromRowType =>
// We have only one single built-in row type.
assert(toRowType == fromRowType)
@@ -174,27 +122,12 @@ object Transition {
case (ConventionReq.RowType.Any,
ConventionReq.BatchType.Is(toBatchType)) =>
from.batchType match {
case Convention.BatchType.None =>
-
fromRowTransitions.get(toBatchType).map(_.create()).getOrElse(orElse)
+ graph.transitionOfOption(from.rowType,
toBatchType).getOrElse(orElse)
case fromBatchType =>
if (toBatchType == fromBatchType) {
Transition.empty
} else {
- // Batch type conversion needed.
- //
- // We first look up for batch-to-batch transition. If found
one, return that
- // transition to caller. Otherwise, look for from/to row
transitions, then
- // return a bridged batch-to-row-to-batch transition.
- if (batchTransitions.contains((fromBatchType, toBatchType)))
{
- // 1. Found batch-to-batch transition.
- batchTransitions((fromBatchType, toBatchType)).create()
- } else {
- // 2. Otherwise, build up batch-to-row-to-batch transition.
- val batchToRow =
-
toRowTransitions.get(fromBatchType).map(_.create()).getOrElse(orElse)
- val rowToBatch =
-
fromRowTransitions.get(toBatchType).map(_.create()).getOrElse(orElse)
- chain(batchToRow, rowToBatch)
- }
+ graph.transitionOfOption(fromBatchType,
toBatchType).getOrElse(orElse)
}
}
case (ConventionReq.RowType.Any, ConventionReq.BatchType.Any) =>
@@ -205,8 +138,12 @@ object Transition {
}
out
}
-
- override private[transition] def update(): MutableFactory = this
}
}
+
  /**
   * Probe plan node used by `Transition#isEmpty`: a transition is treated as a no-op
   * iff applying it to this dummy plan returns the very same instance (`eq` check).
   */
  private case class DummySparkPlan() extends LeafExecNode {
    // Report columnar support so a row-ward transition wrapping this node does not
    // trip ColumnarToRowExec's child-must-be-columnar assertion.
    override def supportsColumnar: Boolean = true // To bypass the assertion in ColumnarToRowExec.
    // Never actually executed; the node exists purely for identity comparison.
    override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
    // Produces no attributes; output schema is irrelevant to the probe.
    override def output: Seq[Attribute] = Nil
  }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
new file mode 100644
index 0000000000..9cafcae8b5
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.transition
+
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.util.SparkReflectionUtil
+
/**
 * A specialization of [[FloydWarshallGraph]] whose vertices are conventions (row/batch
 * types) and whose edges are [[Transition]]s. The all-pairs-shortest-path solution of
 * the graph yields the cheapest transition between any two conventions.
 */
object TransitionGraph {
  /** Marker trait for graph vertices; renders as the implementing class's simple name. */
  trait Vertex {
    override def toString: String = SparkReflectionUtil.getSimpleClassName(this.getClass)
  }

  type Builder = FloydWarshallGraph.Builder[TransitionGraph.Vertex, Transition]

  /** Creates a graph builder wired with the transition-count cost model below. */
  def builder(): Builder = {
    FloydWarshallGraph.builder(TransitionCostModel)
  }

  /** Convenience operations for querying a solved transition graph. */
  implicit class TransitionGraphOps(val graph: TransitionGraph) {
    import TransitionGraphOps._
    def hasTransition(from: TransitionGraph.Vertex, to: TransitionGraph.Vertex): Boolean = {
      graph.hasPath(from, to)
    }

    /**
     * Returns the transition along the cheapest path from `from` to `to`, chaining the
     * path's edges in order. An empty path (from == to) yields `Transition.empty`.
     * Assumes a path exists; callers should check `hasTransition` first.
     */
    def transitionOf(from: TransitionGraph.Vertex, to: TransitionGraph.Vertex): Transition = {
      val path = graph.pathOf(from, to)
      val out = path.edges().reduceOption((l, r) => chain(l, r)).getOrElse(Transition.empty)
      out
    }

    /** Optional variant of [[transitionOf]]: `None` when no path exists. */
    def transitionOfOption(
        from: TransitionGraph.Vertex,
        to: TransitionGraph.Vertex): Option[Transition] = {
      if (!hasTransition(from, to)) {
        return None
      }
      Some(transitionOf(from, to))
    }
  }

  /** Applies `first` then `second`; used to fuse a multi-hop path into one transition. */
  private case class ChainedTransition(first: Transition, second: Transition) extends Transition {
    override def apply0(plan: SparkPlan): SparkPlan = {
      second(first(plan))
    }
  }

  private object TransitionGraphOps {
    // Collapses two transitions, preserving emptiness: chaining two no-ops stays the
    // canonical Transition.empty rather than an empty-looking ChainedTransition.
    private def chain(first: Transition, second: Transition): Transition = {
      if (first.isEmpty && second.isEmpty) {
        return Transition.empty
      }
      ChainedTransition(first, second)
    }
  }

  /** Cost = number of non-empty hops; addition is plain integer addition. */
  private case class TransitionCost(count: Int) extends FloydWarshallGraph.Cost {
    override def +(other: FloydWarshallGraph.Cost): TransitionCost = {
      other match {
        case TransitionCost(otherCount) => TransitionCost(count + otherCount)
      }
    }
  }

  /**
   * Cost model minimizing the number of actual (non-empty) transitions. A chained
   * transition's cost is the sum of its parts, so pre-fused edges are not cheaper
   * than the equivalent multi-edge path.
   */
  private object TransitionCostModel extends FloydWarshallGraph.CostModel[Transition] {
    override def zero(): FloydWarshallGraph.Cost = TransitionCost(0)
    override def costOf(transition: Transition): FloydWarshallGraph.Cost = costOf0(transition)
    override def costComparator(): Ordering[FloydWarshallGraph.Cost] = Ordering.Int.on {
      case TransitionCost(c) => c
    }
    // Empty transitions are free; chains cost the sum of their parts; anything else
    // counts as a single hop.
    private def costOf0(transition: Transition): TransitionCost = transition match {
      case t if t.isEmpty => TransitionCost(0)
      case ChainedTransition(f, s) => costOf0(f) + costOf0(s)
      case _ => TransitionCost(1)
    }
  }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index 1441814519..9987a65b0c 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -89,7 +89,7 @@ object Transitions {
}
def toBackendBatchPlan(plan: SparkPlan): SparkPlan = {
- val backendBatchType = Backend.get().batchType
+ val backendBatchType = Backend.get().defaultBatchType
val out = toBatchPlan(plan, backendBatchType)
out
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
index 2dd8d632e3..0a0deb17de 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
@@ -23,6 +23,7 @@ import
org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.debug.DebugExec
package object transition {
+ type TransitionGraph = FloydWarshallGraph[TransitionGraph.Vertex, Transition]
// These 5 plan operators (as of Spark 3.5) are operators that have the
// same convention with their children.
//
diff --git
a/gluten-core/src/main/scala/org/apache/spark/util/SparkReflectionUtil.scala
b/gluten-core/src/main/scala/org/apache/spark/util/SparkReflectionUtil.scala
new file mode 100644
index 0000000000..40692346e0
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkReflectionUtil.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
/** Small reflection helpers that delegate to Spark's internal [[Utils]]. */
object SparkReflectionUtil {
  // NOTE(review): placed in the org.apache.spark.util package, presumably so Spark's
  // non-public Utils.getSimpleName is accessible here -- confirm against Spark's
  // visibility modifiers for the targeted Spark versions.
  def getSimpleClassName(cls: Class[_]): String = {
    Utils.getSimpleName(cls)
  }
}
diff --git
a/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
new file mode 100644
index 0000000000..6bc4ab804f
--- /dev/null
+++
b/gluten-core/src/test/scala/org/apache/gluten/extension/columnar/transition/FloydWarshallGraphSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.transition
+
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util.concurrent.atomic.AtomicInteger
+
class FloydWarshallGraphSuite extends AnyFunSuite {
  import FloydWarshallGraphSuite._

  test("Sanity") {
    // Topology under test (edge labels are distances):
    //   v0 --5--> v1 --6--> v2
    //   v0 --2--> v3 --1--> v4 --3--> v2
    val Seq(v0, v1, v2, v3, v4) = Seq.fill(5)(Vertex())

    val e01 = Edge(5)
    val e12 = Edge(6)
    val e03 = Edge(2)
    val e34 = Edge(1)
    val e42 = Edge(3)

    val vertices = Seq(v0, v1, v2, v3, v4)
    val edges = Seq((v0, v1, e01), (v1, v2, e12), (v0, v3, e03), (v3, v4, e34), (v4, v2, e42))

    val withVertices = vertices.foldLeft(FloydWarshallGraph.builder(CostModel))(_.addVertex(_))
    val graph = edges.foldLeft(withVertices) {
      case (builder, (from, to, edge)) => builder.addEdge(from, to, edge)
    }.build()

    // Reachability respects edge direction: forward paths exist, reverse ones don't.
    assert(graph.hasPath(v0, v1))
    assert(graph.hasPath(v0, v2))
    assert(!graph.hasPath(v1, v0))
    assert(!graph.hasPath(v2, v0))

    // A vertex reaches itself via the empty path.
    assert(graph.pathOf(v0, v0).edges() == Nil)

    // Direct edges come back as single-hop paths.
    assert(graph.pathOf(v0, v1).edges() == Seq(e01))
    assert(graph.pathOf(v1, v2).edges() == Seq(e12))
    assert(graph.pathOf(v0, v3).edges() == Seq(e03))
    assert(graph.pathOf(v3, v4).edges() == Seq(e34))
    assert(graph.pathOf(v4, v2).edges() == Seq(e42))

    // v0 -> v2: the three-hop route (2 + 1 + 3 = 6) beats the two-hop route
    // (5 + 6 = 11), so the solver must pick it despite the extra hop.
    assert(graph.pathOf(v0, v2).edges() == Seq(e03, e34, e42))
  }
}
+
private object FloydWarshallGraphSuite {
  // Fixture vertex; identity is a process-wide sequence number handed out below.
  case class Vertex private (id: Int)

  private object Vertex {
    private val nextId = new AtomicInteger(0)

    def apply(): Vertex = Vertex(nextId.getAndIncrement())
  }

  // Fixture edge carrying the distance consumed by CostModel.
  case class Edge private (id: Int, distance: Long)

  private object Edge {
    private val nextId = new AtomicInteger(0)

    def apply(distance: Long): Edge = Edge(nextId.getAndIncrement(), distance)
  }

  // Simple additive cost wrapping a Long.
  private case class LongCost(c: Long) extends FloydWarshallGraph.Cost {
    override def +(other: FloydWarshallGraph.Cost): FloydWarshallGraph.Cost = other match {
      case LongCost(o) => LongCost(c + o)
    }
  }

  private object CostModel extends FloydWarshallGraph.CostModel[Edge] {
    override def zero(): FloydWarshallGraph.Cost = LongCost(0)
    // An edge's cost is its distance scaled by a constant factor of 10; the scaling
    // does not affect path ordering.
    override def costOf(edge: Edge): FloydWarshallGraph.Cost = LongCost(edge.distance * 10)
    override def costComparator(): Ordering[FloydWarshallGraph.Cost] = Ordering.Long.on {
      case LongCost(c) => c
    }
  }
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 9b05d567e2..c658a43760 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -109,7 +109,7 @@ trait GlutenPlan extends SparkPlan with
Convention.KnownBatchType with LogLevelU
}
protected def batchType0(): Convention.BatchType = {
- Backend.get().batchType
+ Backend.get().defaultBatchType
}
protected def doValidateInternal(): ValidationResult =
ValidationResult.succeeded
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
index c38ea8af02..9ebd722781 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PlanUtil.scala
@@ -27,7 +27,7 @@ import scala.annotation.tailrec
object PlanUtil {
private def isGlutenTableCacheInternal(i: InMemoryTableScanExec): Boolean = {
- Convention.get(i).batchType == Backend.get().batchType
+ Convention.get(i).batchType == Backend.get().defaultBatchType
}
@tailrec
@@ -47,6 +47,6 @@ object PlanUtil {
}
def isGlutenColumnarOp(plan: SparkPlan): Boolean = {
- Convention.get(plan).batchType == Backend.get().batchType
+ Convention.get(plan).batchType == Backend.get().defaultBatchType
}
}
diff --git
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
index c57df192c5..ada7283da8 100644
---
a/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
+++
b/gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala
@@ -166,7 +166,7 @@ case class ColumnarInputAdapter(child: SparkPlan)
override def output: Seq[Attribute] = child.output
override def supportsColumnar: Boolean = true
override def batchType(): Convention.BatchType =
- Backend.get().batchType
+ Backend.get().defaultBatchType
override protected def doExecute(): RDD[InternalRow] = throw new
UnsupportedOperationException()
override protected def doExecuteColumnar(): RDD[ColumnarBatch] =
child.executeColumnar()
override def outputPartitioning: Partitioning = child.outputPartitioning
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]