This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d6b298221 [VL] RAS: Reuse same code path with heuristic planner for convention enforcement (#5824)
d6b298221 is described below
commit d6b298221f1360626e52862985f78abc7436183d
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed May 22 17:55:00 2024 +0800
[VL] RAS: Reuse same code path with heuristic planner for convention enforcement (#5824)
---
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 3 +-
.../org/apache/gluten/planner/VeloxRasSuite.scala | 15 ++-
.../gluten/execution/ColumnarToRowExecBase.scala | 10 +-
.../org/apache/gluten/extension/GlutenPlan.scala | 3 +-
.../columnar/enumerated/EnumeratedTransform.scala | 13 +-
.../extension/columnar/transition/Convention.scala | 4 +
.../columnar/transition/ConventionFunc.scala | 115 +++++++++++++---
.../columnar/transition/ConventionReq.scala | 17 ++-
.../extension/columnar/transition/Transition.scala | 30 ++++-
.../columnar/transition/Transitions.scala | 48 +------
.../gluten/planner/plan/GlutenPlanModel.scala | 70 +++++++---
.../org/apache/gluten/planner/property/Conv.scala | 106 +++++++++++++++
.../gluten/planner/property/Convention.scala | 147 ---------------------
.../planner/property/GlutenPropertyModel.scala | 6 +-
.../org/apache/spark/util/SparkTaskUtil.scala | 21 ++-
.../apache/gluten/columnarbatch/ArrowBatch.scala | 7 +-
.../scala/org/apache/gluten/ras/PlanModel.scala | 2 +-
.../src/main/scala/org/apache/gluten/ras/Ras.scala | 22 ++-
.../scala/org/apache/gluten/ras/RasGroup.scala | 10 +-
.../main/scala/org/apache/gluten/ras/RasNode.scala | 6 +-
.../gluten/ras/exaustive/ExhaustivePlanner.scala | 2 +-
.../org/apache/gluten/ras/rule/RuleApplier.scala | 2 +-
.../apache/gluten/ras/vis/GraphvizVisualizer.scala | 2 +-
.../org/apache/gluten/ras/OperationSuite.scala | 7 +-
.../org/apache/gluten/ras/PropertySuite.scala | 9 +-
.../scala/org/apache/gluten/ras/RasSuiteBase.scala | 8 +-
.../org/apache/gluten/ras/mock/MockMemoState.scala | 2 +-
.../org/apache/gluten/ras/mock/MockRasPath.scala | 2 +-
.../gluten/ras/specific/DistributedSuite.scala | 18 +--
29 files changed, 401 insertions(+), 306 deletions(-)
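In short, the enumerated (RAS) planner previously carried its own Convention property and ConventionEnforcerRule; this patch replaces them with a thin Conv property that wraps the Convention / ConventionReq types already used by the heuristic planner's InsertTransitions rule, so convention enforcement in both planners now goes through the shared Transition machinery. A minimal sketch of how a convention constraint is expressed for the RAS planner, modeled on the VeloxRasSuite changes below (the ras and plan parameters are placeholders, not part of this patch):

    import org.apache.gluten.extension.columnar.transition.ConventionReq
    import org.apache.gluten.planner.property.Conv
    import org.apache.gluten.ras.Ras
    import org.apache.gluten.ras.property.PropertySet

    import org.apache.spark.sql.execution.SparkPlan

    // Ask the planner for a row-based result. Conv.req wraps the same ConventionReq
    // values that the heuristic InsertTransitions rule consumes.
    def planWithRowOutput(ras: Ras[SparkPlan], plan: SparkPlan): SparkPlan = {
      val planner = ras.newPlanner(plan, PropertySet(List(Conv.req(ConventionReq.row))))
      planner.plan()
    }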
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 2d37b1185..322116582 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -81,7 +81,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
*/
override def batchTypeFunc(): BatchOverride = {
case i: InMemoryTableScanExec
- if i.relation.cacheBuilder.serializer.isInstanceOf[ColumnarCachedBatchSerializer] =>
+ if i.supportsColumnar && i.relation.cacheBuilder.serializer
+ .isInstanceOf[ColumnarCachedBatchSerializer] =>
VeloxBatch
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
index 4690ef516..ae2cea0ba 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
@@ -16,7 +16,8 @@
*/
package org.apache.gluten.planner
-import org.apache.gluten.planner.property.Conventions
+import org.apache.gluten.extension.columnar.transition.ConventionReq
+import org.apache.gluten.planner.property.Conv
import org.apache.gluten.ras.Best.BestNotFoundException
import org.apache.gluten.ras.Ras
import org.apache.gluten.ras.RasSuiteBase._
@@ -44,7 +45,7 @@ class VeloxRasSuite extends SharedSparkSession {
test("C2R, R2C - explicitly requires any properties") {
val in = RowUnary(RowLeaf(TRIVIAL_SCHEMA))
val planner =
- newRas().newPlanner(in, PropertySet(List(Conventions.ANY)))
+ newRas().newPlanner(in, PropertySet(List(Conv.any)))
val out = planner.plan()
assert(out == RowUnary(RowLeaf(TRIVIAL_SCHEMA)))
}
@@ -52,7 +53,7 @@ class VeloxRasSuite extends SharedSparkSession {
test("C2R, R2C - requires columnar output") {
val in = RowUnary(RowLeaf(TRIVIAL_SCHEMA))
val planner =
- newRas().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR)))
+ newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.vanillaBatch))))
val out = planner.plan()
assert(out == RowToColumnarExec(RowUnary(RowLeaf(TRIVIAL_SCHEMA))))
}
@@ -63,7 +64,7 @@ class VeloxRasSuite extends SharedSparkSession {
RowUnary(
RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf(TRIVIAL_SCHEMA))))))))
val planner =
- newRas().newPlanner(in, PropertySet(List(Conventions.ROW_BASED)))
+ newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.row))))
val out = planner.plan()
assert(
out == ColumnarToRowExec(
@@ -91,7 +92,7 @@ class VeloxRasSuite extends SharedSparkSession {
RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf(TRIVIAL_SCHEMA))))))))
val planner =
newRas(List(ConvertRowUnaryToColumnar))
- .newPlanner(in, PropertySet(List(Conventions.ROW_BASED)))
+ .newPlanner(in, PropertySet(List(Conv.req(ConventionReq.row))))
val out = planner.plan()
assert(out ==
ColumnarToRowExec(ColumnarUnary(ColumnarUnary(ColumnarUnary(ColumnarUnary(
ColumnarUnary(ColumnarUnary(ColumnarUnary(RowToColumnarExec(RowLeaf(TRIVIAL_SCHEMA)))))))))))
@@ -104,7 +105,7 @@ class VeloxRasSuite extends SharedSparkSession {
val in = RowUnary(RowLeaf(EMPTY_SCHEMA))
val planner =
- newRas().newPlanner(in, PropertySet(List(Conventions.ANY)))
+ newRas().newPlanner(in, PropertySet(List(Conv.any)))
val out = planner.plan()
assert(out == RowUnary(RowLeaf(EMPTY_SCHEMA)))
@@ -112,7 +113,7 @@ class VeloxRasSuite extends SharedSparkSession {
// Could not optimize to columnar output since R2C transitions for empty schema node
// is not allowed.
val planner2 =
- newRas().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR)))
+ newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.vanillaBatch))))
planner2.plan()
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
index 6d3fa2dac..fd86106bf 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
@@ -18,6 +18,8 @@ package org.apache.gluten.execution
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.transition.ConventionReq
+import org.apache.gluten.extension.columnar.transition.ConventionReq.KnownChildrenConventions
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -28,7 +30,8 @@ import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
abstract class ColumnarToRowExecBase(child: SparkPlan)
extends ColumnarToRowTransition
- with GlutenPlan {
+ with GlutenPlan
+ with KnownChildrenConventions {
// Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
@transient override lazy val metrics =
@@ -50,4 +53,9 @@ abstract class ColumnarToRowExecBase(child: SparkPlan)
override def doExecute(): RDD[InternalRow] = {
doExecuteInternal()
}
+
+ override def requiredChildrenConventions(): Seq[ConventionReq] = {
+ List(ConventionReq.backendBatch)
+ }
+
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 033e44b8c..8f1004be4 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -88,8 +88,7 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU
final override def batchType(): Convention.BatchType = {
if (!supportsColumnar) {
- throw new UnsupportedOperationException(
- s"Node $nodeName doesn't support columnar-batch processing")
+ return Convention.BatchType.None
}
val batchType = batchType0()
assert(batchType != Convention.BatchType.None)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index dc34bc1af..50f0dce13 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -17,8 +17,9 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin, OffloadOthers, OffloadSingleNode}
+import org.apache.gluten.extension.columnar.transition.ConventionReq
import org.apache.gluten.planner.GlutenOptimization
-import org.apache.gluten.planner.property.Conventions
+import org.apache.gluten.planner.property.Conv
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.utils.LogLevelUtil
@@ -48,9 +49,13 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
private val optimization = GlutenOptimization(rules ++ offloadRules)
- private val reqConvention = Conventions.ANY
- private val altConventions =
- Seq(Conventions.GLUTEN_COLUMNAR, Conventions.ROW_BASED)
+ private val reqConvention = Conv.any
+
+ private val altConventions = {
+ val rowBased: Conv = Conv.req(ConventionReq.row)
+ val backendBatchBased: Conv = Conv.req(ConventionReq.backendBatch)
+ Seq(rowBased, backendBatchBased)
+ }
override def apply(plan: SparkPlan): SparkPlan = {
val constraintSet = PropertySet(List(reqConvention))
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index 2774497d9..034b45851 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -110,4 +110,8 @@ object Convention {
trait KnownBatchType {
def batchType(): BatchType
}
+
+ trait KnownRowType {
+ def rowType(): RowType
+ }
}
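The hunk above adds Convention.KnownRowType alongside the existing KnownBatchType, so an operator can declare its row-side convention directly instead of being probed through the Spark shims. A sketch of a hypothetical columnar-only node (not part of this patch) mixing in both traits:

    import org.apache.gluten.extension.columnar.transition.Convention

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

    // Hypothetical operator: ConventionFunc can now derive both sides of its convention
    // from the traits rather than from pattern matching on the node class.
    case class MyColumnarOnlyExec(child: SparkPlan)
      extends UnaryExecNode
      with Convention.KnownBatchType
      with Convention.KnownRowType {
      override def output: Seq[Attribute] = child.output
      override def supportsColumnar: Boolean = true
      override def batchType(): Convention.BatchType = Convention.BatchType.VanillaBatch
      override def rowType(): Convention.RowType = Convention.RowType.None
      override protected def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException("columnar-only")
      override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
        copy(child = newChild)
    }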
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
index 28bd1d12c..453df5d88 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
@@ -17,20 +17,22 @@
package org.apache.gluten.extension.columnar.transition
import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.transition.Convention.{BatchType, RowType}
+import org.apache.gluten.extension.columnar.transition.ConventionReq.KnownChildrenConventions
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnionExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-/** ConventionFunc is a utility to derive [[Convention]] from a query plan. */
+/** ConventionFunc is a utility to derive [[Convention]] or [[ConventionReq]] from a query plan. */
trait ConventionFunc {
def conventionOf(plan: SparkPlan): Convention
+ def conventionReqOf(plan: SparkPlan): ConventionReq
}
object ConventionFunc {
- type BatchOverride = PartialFunction[SparkPlan, BatchType]
+ type BatchOverride = PartialFunction[SparkPlan, Convention.BatchType]
// For testing, to make things work without a backend loaded.
private var ignoreBackend: Boolean = false
@@ -47,18 +49,22 @@ object ConventionFunc {
}
def create(): ConventionFunc = {
+ val batchOverride = newOverride()
+ new BuiltinFunc(batchOverride)
+ }
+
+ private def newOverride(): BatchOverride = {
synchronized {
if (ignoreBackend) {
// For testing
- return new BuiltinFunc(PartialFunction.empty)
+ return PartialFunction.empty
}
}
- val batchOverride = BackendsApiManager.getSparkPlanExecApiInstance.batchTypeFunc()
- new BuiltinFunc(batchOverride)
+ BackendsApiManager.getSparkPlanExecApiInstance.batchTypeFunc()
}
private class BuiltinFunc(o: BatchOverride) extends ConventionFunc {
-
+ import BuiltinFunc._
override def conventionOf(plan: SparkPlan): Convention = {
val conv = conventionOf0(plan)
conv
@@ -82,7 +88,7 @@ object ConventionFunc {
// See org.apache.gluten.extension.columnar.transition.InsertTransitions.apply
BackendsApiManager.getSparkPlanExecApiInstance.batchType
} else {
- BatchType.None
+ Convention.BatchType.None
}
val conv = Convention.of(rowType, batchType)
conv
@@ -91,25 +97,94 @@ object ConventionFunc {
conv
}
- private def rowTypeOf(plan: SparkPlan): RowType = {
- if (!SparkShimLoader.getSparkShims.supportsRowBased(plan)) {
- return RowType.None
+ private def rowTypeOf(plan: SparkPlan): Convention.RowType = {
+ val out = plan match {
+ case k: Convention.KnownRowType =>
+ k.rowType()
+ case _ if !SparkShimLoader.getSparkShims.supportsRowBased(plan) =>
+ Convention.RowType.None
+ case _ =>
+ Convention.RowType.VanillaRow
}
- RowType.VanillaRow
+ if (out != Convention.RowType.None) {
+ assert(SparkShimLoader.getSparkShims.supportsRowBased(plan))
+ }
+ out
}
- private def batchTypeOf(plan: SparkPlan): BatchType = {
- if (!plan.supportsColumnar) {
- return BatchType.None
- }
- o.applyOrElse(
+ private def batchTypeOf(plan: SparkPlan): Convention.BatchType = {
+ val out = o.applyOrElse(
plan,
(p: SparkPlan) =>
p match {
- case g: Convention.KnownBatchType => g.batchType()
- case _ => BatchType.VanillaBatch
+ case k: Convention.KnownBatchType =>
+ k.batchType()
+ case _ if !plan.supportsColumnar =>
+ Convention.BatchType.None
+ case _ =>
+ Convention.BatchType.VanillaBatch
}
)
+ if (out != Convention.BatchType.None) {
+ assert(plan.supportsColumnar)
+ }
+ out
+ }
+
+ override def conventionReqOf(plan: SparkPlan): ConventionReq = {
+ val out = conventionReqOf0(plan)
+ out
+ }
+
+ private def conventionReqOf0(plan: SparkPlan): ConventionReq = plan match {
+ case k: KnownChildrenConventions =>
+ val reqs = k.requiredChildrenConventions().distinct
+ // This can be a temporary restriction.
+ assert(
+ reqs.size == 1,
+ "KnownChildrenConventions#requiredChildrenConventions should output
the same element" +
+ " for all children")
+ reqs.head
+ case RowToColumnarLike(_) =>
+ ConventionReq.of(
+ ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
+ ConventionReq.BatchType.Any)
+ case ColumnarToRowExec(_) =>
+ ConventionReq.of(
+ ConventionReq.RowType.Any,
+ ConventionReq.BatchType.Is(Convention.BatchType.VanillaBatch))
+ case write: DataWritingCommandExec if SparkShimLoader.getSparkShims.isPlannedV1Write(write) =>
+ // To align with ApplyColumnarRulesAndInsertTransitions#insertTransitions
+ ConventionReq.any
+ case u: UnionExec =>
+ // We force vanilla union to output row data to get best compatibility with vanilla Spark.
+ // As a result it's a common practice to rewrite it with GlutenPlan for offloading.
+ ConventionReq.of(
+ ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
+ ConventionReq.BatchType.Any)
+ case other =>
+ // In the normal case, children's convention should follow parent node's convention.
+ // Note, we don't have consider C2R / R2C here since they are already removed by
+ // RemoveTransitions.
+ val thisConv = conventionOf0(other)
+ thisConv.asReq()
+ }
+ }
+
+ private object BuiltinFunc {
+ implicit private class ConventionOps(conv: Convention) {
+ def asReq(): ConventionReq = {
+ val rowTypeReq = conv.rowType match {
+ case Convention.RowType.None => ConventionReq.RowType.Any
+ case r => ConventionReq.RowType.Is(r)
+ }
+
+ val batchTypeReq = conv.batchType match {
+ case Convention.BatchType.None => ConventionReq.BatchType.Any
+ case b => ConventionReq.BatchType.Is(b)
+ }
+ ConventionReq.of(rowTypeReq, batchTypeReq)
+ }
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
index aac2084a7..65422b380 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
@@ -16,6 +16,10 @@
*/
package org.apache.gluten.extension.columnar.transition
+import org.apache.gluten.backendsapi.BackendsApiManager
+
+import org.apache.spark.sql.execution.SparkPlan
+
/**
 * ConventionReq describes the requirement for [[Convention]]. This is mostly used in determining
* the acceptable conventions for its children of a parent plan node.
@@ -50,5 +54,16 @@ object ConventionReq {
) extends ConventionReq
val any: ConventionReq = Impl(RowType.Any, BatchType.Any)
- def of(rowType: RowType, batchType: BatchType): ConventionReq = new Impl(rowType, batchType)
+ val row: ConventionReq = Impl(RowType.Is(Convention.RowType.VanillaRow), BatchType.Any)
+ val vanillaBatch: ConventionReq =
+ Impl(RowType.Any, BatchType.Is(Convention.BatchType.VanillaBatch))
+ lazy val backendBatch: ConventionReq =
+ Impl(RowType.Any, BatchType.Is(BackendsApiManager.getSparkPlanExecApiInstance.batchType))
+
+ def get(plan: SparkPlan): ConventionReq = ConventionFunc.create().conventionReqOf(plan)
+ def of(rowType: RowType, batchType: BatchType): ConventionReq = Impl(rowType, batchType)
+
+ trait KnownChildrenConventions {
+ def requiredChildrenConventions(): Seq[ConventionReq]
+ }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 9b745f94d..73a126f8d 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -29,7 +29,21 @@ import scala.collection.mutable
 * [[org.apache.gluten.extension.columnar.transition.Convention.BatchType]]'s definition.
*/
trait Transition {
- def apply(plan: SparkPlan): SparkPlan
+ final def apply(plan: SparkPlan): SparkPlan = {
+ val out = apply0(plan)
+ if (out.fastEquals(plan)) {
+ assert(
+ this == Transition.empty,
+ "TransitionDef.empty / Transition.empty should be used when defining
an empty transition.")
+ }
+ out
+ }
+
+ final def isEmpty: Boolean = {
+ this == Transition.empty
+ }
+
+ protected def apply0(plan: SparkPlan): SparkPlan
}
trait TransitionDef {
@@ -53,12 +67,15 @@ object Transition {
}
private class ChainedTransition(first: Transition, second: Transition)
extends Transition {
- override def apply(plan: SparkPlan): SparkPlan = {
+ override def apply0(plan: SparkPlan): SparkPlan = {
second(first(plan))
}
}
private def chain(first: Transition, second: Transition): Transition = {
+ if (first.isEmpty && second.isEmpty) {
+ return Transition.empty
+ }
new ChainedTransition(first, second)
}
@@ -72,6 +89,15 @@ object Transition {
}
}
+ final def satisfies(conv: Convention, req: ConventionReq): Boolean = {
+ val none = new Transition {
+ override protected def apply0(plan: SparkPlan): SparkPlan =
+ throw new UnsupportedOperationException()
+ }
+ val transition = findTransition(conv, req)(none)
+ transition.isEmpty
+ }
+
protected def findTransition(from: Convention, to: ConventionReq)(
orElse: => Transition): Transition
private[transition] def update(): MutableFactory
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index e0758cff7..d02aadd49 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -17,16 +17,13 @@
package org.apache.gluten.extension.columnar.transition
import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
-import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.SparkPlan
import scala.annotation.tailrec
case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
- import InsertTransitions._
private val convFunc = ConventionFunc.create()
override def apply(plan: SparkPlan): SparkPlan = {
@@ -47,7 +44,7 @@ case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
if (node.children.isEmpty) {
return node
}
- val convReq = childrenConvReqOf(node)
+ val convReq = convFunc.conventionReqOf(node)
val newChildren = node.children.map {
child =>
val from = convFunc.conventionOf(child)
@@ -64,47 +61,6 @@ case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
}
node.withNewChildren(newChildren)
}
-
- private def childrenConvReqOf(node: SparkPlan): ConventionReq = node match {
- // TODO: Consider C2C transitions as well when we have some.
- case ColumnarToRowLike(_) | RowToColumnarLike(_) =>
- // C2R / R2C here since they are already removed by
- // RemoveTransitions.
- // It's current rule's mission to add C2Rs / R2Cs on demand.
- throw new IllegalStateException("Unreachable code")
- case write: DataWritingCommandExec if SparkShimLoader.getSparkShims.isPlannedV1Write(write) =>
- // To align with ApplyColumnarRulesAndInsertTransitions#insertTransitions
- ConventionReq.any
- case u: UnionExec =>
- // We force vanilla union to output row data to get best compatibility with vanilla Spark.
- // As a result it's a common practice to rewrite it with GlutenPlan for offloading.
- ConventionReq.of(
- ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
- ConventionReq.BatchType.Any)
- case other =>
- // In the normal case, children's convention should follow parent node's convention.
- // Note, we don't have consider C2R / R2C here since they are already removed by
- // RemoveTransitions.
- val thisConv = convFunc.conventionOf(other)
- thisConv.asReq()
- }
-}
-
-object InsertTransitions {
- implicit private class ConventionOps(conv: Convention) {
- def asReq(): ConventionReq = {
- val rowTypeReq = conv.rowType match {
- case Convention.RowType.None => ConventionReq.RowType.Any
- case r => ConventionReq.RowType.Is(r)
- }
-
- val batchTypeReq = conv.batchType match {
- case Convention.BatchType.None => ConventionReq.BatchType.Any
- case b => ConventionReq.BatchType.Is(b)
- }
- ConventionReq.of(rowTypeReq, batchTypeReq)
- }
- }
}
object RemoveTransitions extends Rule[SparkPlan] {
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
index f0ae4286f..7417d9a5d 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
@@ -16,15 +16,19 @@
*/
package org.apache.gluten.planner.plan
+import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
+import org.apache.gluten.extension.columnar.transition.Convention.{KnownBatchType, KnownRowType}
import org.apache.gluten.planner.metadata.GlutenMetadata
-import org.apache.gluten.planner.property.{ConventionDef, Conventions}
+import org.apache.gluten.planner.property.{Conv, ConvDef}
import org.apache.gluten.ras.{Metadata, PlanModel}
import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode, SparkPlan}
+import org.apache.spark.util.{SparkTaskUtil, TaskResources}
import java.util.Objects
@@ -36,25 +40,61 @@ object GlutenPlanModel {
case class GroupLeafExec(
groupId: Int,
metadata: GlutenMetadata,
- propertySet: PropertySet[SparkPlan])
- extends LeafExecNode {
+ constraintSet: PropertySet[SparkPlan])
+ extends LeafExecNode
+ with KnownBatchType
+ with KnownRowType {
+ private val req: Conv.Req = constraintSet.get(ConvDef).asInstanceOf[Conv.Req]
+
 override protected def doExecute(): RDD[InternalRow] = throw new IllegalStateException()
override def output: Seq[Attribute] = metadata.schema().output
- override def supportsColumnar: Boolean =
- propertySet.get(ConventionDef) match {
- case Conventions.ROW_BASED => false
- case Conventions.VANILLA_COLUMNAR => true
- case Conventions.GLUTEN_COLUMNAR => true
- case Conventions.ANY => true
+
+ override def supportsColumnar(): Boolean = {
+ batchType != Convention.BatchType.None
+ }
+
+ override val batchType: Convention.BatchType = {
+ val out = req.req.requiredBatchType match {
+ case ConventionReq.BatchType.Any => Convention.BatchType.None
+ case ConventionReq.BatchType.Is(b) => b
+ }
+ out
+ }
+
+ override val rowType: Convention.RowType = {
+ val out = req.req.requiredRowType match {
+ case ConventionReq.RowType.Any => Convention.RowType.None
+ case ConventionReq.RowType.Is(r) => r
}
+ out
+ }
}
private object PlanModelImpl extends PlanModel[SparkPlan] {
+ private val fakeTc = SparkShimLoader.getSparkShims.createTestTaskContext()
+ private def fakeTc[T](body: => T): T = {
+ assert(!TaskResources.inSparkTask())
+ SparkTaskUtil.setTaskContext(fakeTc)
+ try {
+ body
+ } finally {
+ SparkTaskUtil.unsetTaskContext()
+ }
+ }
+
override def childrenOf(node: SparkPlan): Seq[SparkPlan] = node.children
- override def withNewChildren(node: SparkPlan, children: Seq[SparkPlan]): SparkPlan = {
- node.withNewChildren(children)
- }
+ override def withNewChildren(node: SparkPlan, children: Seq[SparkPlan]): SparkPlan =
+ node match {
+ case c2r: ColumnarToRowExec =>
+ // Workaround: To bypass the assertion in ColumnarToRowExec's code if child is
+ // a group leaf.
+ fakeTc {
+ c2r.withNewChildren(children)
+ }
+ case other =>
+ other.withNewChildren(children)
+ }
override def hashCode(node: SparkPlan): Int = Objects.hashCode(node)
@@ -63,8 +103,8 @@ object GlutenPlanModel {
override def newGroupLeaf(
groupId: Int,
metadata: Metadata,
- propSet: PropertySet[SparkPlan]): SparkPlan =
- GroupLeafExec(groupId, metadata.asInstanceOf[GlutenMetadata], propSet)
+ constraintSet: PropertySet[SparkPlan]): SparkPlan =
+ GroupLeafExec(groupId, metadata.asInstanceOf[GlutenMetadata], constraintSet)
override def isGroupLeaf(node: SparkPlan): Boolean = node match {
case _: GroupLeafExec => true
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
new file mode 100644
index 000000000..475f62920
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.planner.property
+
+import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq, Transition}
+import org.apache.gluten.ras.{Property, PropertyDef}
+import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+
+import org.apache.spark.sql.execution._
+
+sealed trait Conv extends Property[SparkPlan] {
+ import Conv._
+ override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]] = {
+ ConvDef
+ }
+
+ override def satisfies(other: Property[SparkPlan]): Boolean = {
+ val req = other.asInstanceOf[Req]
+ if (req.isAny) {
+ return true
+ }
+ val prop = this.asInstanceOf[Prop]
+ val out = Transition.factory.satisfies(prop.prop, req.req)
+ out
+ }
+}
+
+object Conv {
+ val any: Conv = Req(ConventionReq.any)
+
+ def of(conv: Convention): Conv = Prop(conv)
+ def req(req: ConventionReq): Conv = Req(req)
+
+ def get(plan: SparkPlan): Conv = {
+ Conv.of(Convention.get(plan))
+ }
+
+ def findTransition(from: Conv, to: Conv): Transition = {
+ val prop = from.asInstanceOf[Prop]
+ val req = to.asInstanceOf[Req]
+ val out = Transition.factory.findTransition(prop.prop, req.req, new IllegalStateException())
+ out
+ }
+
+ case class Prop(prop: Convention) extends Conv
+ case class Req(req: ConventionReq) extends Conv {
+ def isAny: Boolean = {
+ req.requiredBatchType == ConventionReq.BatchType.Any &&
+ req.requiredRowType == ConventionReq.RowType.Any
+ }
+ }
+}
+
+object ConvDef extends PropertyDef[SparkPlan, Conv] {
+ // TODO: Should the convention-transparent ops (e.g., aqe shuffle read) support
+ // convention-propagation. Probably need to refactor getChildrenPropertyRequirements.
+ override def getProperty(plan: SparkPlan): Conv = {
+ conventionOf(plan)
+ }
+
+ private def conventionOf(plan: SparkPlan): Conv = {
+ val out = Conv.get(plan)
+ out
+ }
+
+ override def getChildrenConstraints(
+ constraint: Property[SparkPlan],
+ plan: SparkPlan): Seq[Conv] = {
+ val out = List.tabulate(plan.children.size)(_ => Conv.req(ConventionReq.get(plan)))
+ out
+ }
+
+ override def any(): Conv = Conv.any
+}
+
+case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] {
+ override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+ if (node.output.isEmpty) {
+ // Disable transitions for node that has output with empty schema.
+ return List.empty
+ }
+ val conv = Conv.get(node)
+ if (conv.satisfies(reqConv)) {
+ return List.empty
+ }
+ val transition = Conv.findTransition(conv, reqConv)
+ val after = transition.apply(node)
+ List(after)
+ }
+
+ override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+}
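Taken together, the pieces above keep the RAS-side enforcer backend-agnostic: it only compares the derived Conv of a node against the requested one and, on a mismatch, applies whatever transition the shared factory resolves. A rough illustration (the enforceRowOutput helper is illustrative only; ConvEnforcerRule above does the equivalent inside a RAS rule):

    import org.apache.gluten.extension.columnar.transition.ConventionReq
    import org.apache.gluten.planner.property.Conv

    import org.apache.spark.sql.execution.SparkPlan

    def enforceRowOutput(node: SparkPlan): SparkPlan = {
      val actual = Conv.get(node)                // derived through ConventionFunc
      val required = Conv.req(ConventionReq.row)
      if (actual.satisfies(required)) {
        node                                     // requirement already met, nothing to add
      } else {
        val transition = Conv.findTransition(actual, required)
        transition.apply(node)                   // e.g. wraps node in a C2R transition
      }
    }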
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Convention.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Convention.scala
deleted file mode 100644
index 5fe96ab79..000000000
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Convention.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.planner.property
-
-import org.apache.gluten.execution.RowToColumnarExecBase
-import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike, Transitions}
-import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
-import org.apache.gluten.ras.{Property, PropertyDef}
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
-import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.utils.PlanUtil
-
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-
-sealed trait Convention extends Property[SparkPlan] {
- override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]] = {
- ConventionDef
- }
-
- override def satisfies(other: Property[SparkPlan]): Boolean = other match {
- case Conventions.ANY => true
- case c: Convention => c == this
- case _ => throw new IllegalStateException()
- }
-}
-
-object Conventions {
- // FIXME: Velox and CH should have different conventions?
- case object ROW_BASED extends Convention
- case object VANILLA_COLUMNAR extends Convention
- case object GLUTEN_COLUMNAR extends Convention
- case object ANY extends Convention
-}
-
-object ConventionDef extends PropertyDef[SparkPlan, Convention] {
- // TODO: Should the convention-transparent ops (e.g., aqe shuffle read) support
- // convention-propagation. Probably need to refactor getChildrenPropertyRequirements.
- override def getProperty(plan: SparkPlan): Convention = plan match {
- case _: GroupLeafExec => throw new IllegalStateException()
- case other => conventionOf(other)
- }
-
- private def conventionOf(plan: SparkPlan): Convention = plan match {
- case g: GroupLeafExec => g.propertySet.get(ConventionDef)
- case ColumnarToRowExec(child) => Conventions.ROW_BASED
- case RowToColumnarExec(child) => Conventions.VANILLA_COLUMNAR
- case ColumnarToRowLike(child) => Conventions.ROW_BASED
- case RowToColumnarLike(child) => Conventions.GLUTEN_COLUMNAR
- case q: QueryStageExec => conventionOf(q.plan)
- case r: ReusedExchangeExec => conventionOf(r.child)
- case a: AdaptiveSparkPlanExec => conventionOf(a.executedPlan)
- case i: InMemoryTableScanExec => getCacheConvention(i)
- case p if canPropagateConvention(p) =>
- val childrenProps = p.children.map(conventionOf).distinct
- assert(childrenProps.size == 1)
- childrenProps.head
- case _: GlutenPlan => Conventions.GLUTEN_COLUMNAR
- case p if p.supportsColumnar => Conventions.VANILLA_COLUMNAR
- case p if SparkShimLoader.getSparkShims.supportsRowBased(p) => Conventions.ROW_BASED
- case other => throw new IllegalStateException(s"Unable to get convention of $other")
- }
-
- override def getChildrenConstraints(
- constraint: Property[SparkPlan],
- plan: SparkPlan): Seq[Convention] = plan match {
- case ColumnarToRowExec(child) => Seq(Conventions.VANILLA_COLUMNAR)
- case ColumnarToRowLike(child) => Seq(Conventions.GLUTEN_COLUMNAR)
- case RowToColumnarLike(child) => Seq(Conventions.ROW_BASED)
- case p if canPropagateConvention(p) =>
- p.children.map(_ => constraint.asInstanceOf[Convention])
- case other =>
- val conv = conventionOf(other)
- other.children.map(_ => conv)
- }
-
- override def any(): Convention = Conventions.ANY
-
- private def canPropagateConvention(plan: SparkPlan): Boolean = plan match {
- case p: AQEShuffleReadExec => true
- case p: InputAdapter => true
- case p: WholeStageCodegenExec => true
- case _ => false
- }
-
- private def getCacheConvention(i: InMemoryTableScanExec): Convention = {
- if (PlanUtil.isGlutenTableCache(i)) {
- Conventions.GLUTEN_COLUMNAR
- } else if (i.supportsColumnar) {
- Conventions.VANILLA_COLUMNAR
- } else {
- Conventions.ROW_BASED
- }
- }
-}
-
-case class ConventionEnforcerRule(reqConv: Convention) extends RasRule[SparkPlan] {
- override def shift(node: SparkPlan): Iterable[SparkPlan] = {
- if (node.output.isEmpty) {
- // Disable transitions for node that has output with empty schema.
- return List.empty
- }
- val conv = ConventionDef.getProperty(node)
- if (conv.satisfies(reqConv)) {
- return List.empty
- }
- (conv, reqConv) match {
- case (Conventions.VANILLA_COLUMNAR, Conventions.ROW_BASED) =>
- List(ColumnarToRowExec(node))
- case (Conventions.ROW_BASED, Conventions.VANILLA_COLUMNAR) =>
- List(RowToColumnarExec(node))
- case (Conventions.GLUTEN_COLUMNAR, Conventions.ROW_BASED) =>
- List(Transitions.toRowPlan(node))
- case (Conventions.ROW_BASED, Conventions.GLUTEN_COLUMNAR) =>
- val attempt = Transitions.toBackendBatchPlan(node)
- if (attempt.asInstanceOf[RowToColumnarExecBase].doValidate().isValid) {
- List(attempt)
- } else {
- List.empty
- }
- case (Conventions.VANILLA_COLUMNAR, Conventions.GLUTEN_COLUMNAR) =>
- List(Transitions.toBackendBatchPlan(ColumnarToRowExec(node)))
- case (Conventions.GLUTEN_COLUMNAR, Conventions.VANILLA_COLUMNAR) =>
- List(RowToColumnarExec(Transitions.toRowPlan(node)))
- case _ => List.empty
- }
- }
-
- override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
-}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
index a998c935d..115ab4471 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
@@ -28,14 +28,14 @@ object GlutenPropertyModel {
private object PropertyModelImpl extends PropertyModel[SparkPlan] {
override def propertyDefs: Seq[PropertyDef[SparkPlan, _ <: Property[SparkPlan]]] =
- Seq(ConventionDef)
+ Seq(ConvDef)
override def newEnforcerRuleFactory(
propertyDef: PropertyDef[SparkPlan, _ <: Property[SparkPlan]])
: EnforcerRuleFactory[SparkPlan] = (reqProp: Property[SparkPlan]) => {
propertyDef match {
- case ConventionDef =>
- Seq(ConventionEnforcerRule(reqProp.asInstanceOf[Convention]))
+ case ConvDef =>
+ Seq(ConvEnforcerRule(reqProp.asInstanceOf[Conv]))
}
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
b/gluten-core/src/main/scala/org/apache/spark/util/SparkTaskUtil.scala
similarity index 63%
copy from gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
copy to gluten-core/src/main/scala/org/apache/spark/util/SparkTaskUtil.scala
index 34924ccbf..92a12b3c6 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkTaskUtil.scala
@@ -14,19 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.gluten.ras
+package org.apache.spark.util
-import org.apache.gluten.ras.property.PropertySet
+import org.apache.spark.TaskContext
-trait PlanModel[T <: AnyRef] {
- // Trivial tree operations.
- def childrenOf(node: T): Seq[T]
- def withNewChildren(node: T, children: Seq[T]): T
- def hashCode(node: T): Int
- def equals(one: T, other: T): Boolean
+object SparkTaskUtil {
+ def setTaskContext(taskContext: TaskContext): Unit = {
+ TaskContext.setTaskContext(taskContext)
+ }
- // Group operations.
- def newGroupLeaf(groupId: Int, meta: Metadata, propSet: PropertySet[T]): T
- def isGroupLeaf(node: T): Boolean
- def getGroupId(node: T): Int
+ def unsetTaskContext(): Unit = {
+ TaskContext.unset()
+ }
}
diff --git
a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
index 3f40793d9..58a88e1f4 100644
--- a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
+++ b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
@@ -17,7 +17,8 @@
package org.apache.gluten.columnarbatch
-import org.apache.gluten.extension.columnar.transition.Convention
+import org.apache.gluten.extension.columnar.transition.{Convention, TransitionDef}
+import org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan}
@@ -38,4 +39,8 @@ object ArrowBatch extends Convention.BatchType {
(plan: SparkPlan) => {
ColumnarToRowExec(plan)
})
+
+ // Arrow batch is one-way compatible with vanilla batch since it provides valid
+ // #get<type>(...) implementations.
+ toBatch(VanillaBatch, TransitionDef.empty)
}
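The toBatch(VanillaBatch, TransitionDef.empty) registration above declares Arrow batches one-way compatible with vanilla batches: the transition factory should resolve an empty transition in that direction, while the reverse still needs a real conversion. A hedged sketch of what this implies for a plan that outputs Arrow batches (arrowPlan is a placeholder; behavior assumed from the Transition factory semantics shown earlier in this patch):

    import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq, Transition}

    import org.apache.spark.sql.execution.SparkPlan

    // Expected: an Arrow-batch output satisfies a vanilla-batch requirement without
    // inserting any transition node, thanks to the empty transition registered above.
    def vanillaReaderCanConsume(arrowPlan: SparkPlan): Boolean =
      Transition.factory.satisfies(Convention.get(arrowPlan), ConventionReq.vanillaBatch)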
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
index 34924ccbf..bac9d0b64 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
@@ -26,7 +26,7 @@ trait PlanModel[T <: AnyRef] {
def equals(one: T, other: T): Boolean
// Group operations.
- def newGroupLeaf(groupId: Int, meta: Metadata, propSet: PropertySet[T]): T
+ def newGroupLeaf(groupId: Int, meta: Metadata, constraintSet: PropertySet[T]): T
def isGroupLeaf(node: T): Boolean
def getGroupId(node: T): Int
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 804d04d81..f705a2901 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -30,9 +30,7 @@ trait Optimization[T <: AnyRef] {
plan: T,
constraintSet: PropertySet[T],
altConstraintSets: Seq[PropertySet[T]]): RasPlanner[T]
-
- def propSetOf(plan: T): PropertySet[T]
-
+ def anyPropSet(): PropertySet[T]
def withNewConfig(confFunc: RasConfig => RasConfig): Optimization[T]
}
@@ -49,7 +47,7 @@ object Optimization {
implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
def newPlanner(plan: T): RasPlanner[T] = {
- opt.newPlanner(plan, opt.propSetOf(plan), List.empty)
+ opt.newPlanner(plan, opt.anyPropSet(), List.empty)
}
def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T] = {
opt.newPlanner(plan, constraintSet, List.empty)
@@ -113,15 +111,6 @@ class Ras[T <: AnyRef] private (
// Node groups don't have user-defined cost, expect exception here.
metadataModel.metadataOf(dummyGroup)
}
- propertyModel.propertyDefs.foreach {
- propDef =>
- // Node groups don't have user-defined property, expect exception here.
- assertThrows(
- "Group is not allowed to return its property directly to optimizer
(optimizer already" +
- " knew that). It's expected to throw an exception when getting its
property but not") {
- propDef.getProperty(dummyGroup)
- }
- }
}
override def newPlanner(
@@ -131,7 +120,12 @@ class Ras[T <: AnyRef] private (
RasPlanner(this, altConstraintSets, constraintSet, plan)
}
- override def propSetOf(plan: T): PropertySet[T] = propertySetFactory().get(plan)
+ override def anyPropSet(): PropertySet[T] = propertySetFactory().any()
+
+ private[ras] def propSetOf(plan: T): PropertySet[T] = {
+ val out = propertySetFactory().get(plan)
+ out
+ }
private[ras] def withNewChildren(node: T, newChildren: Seq[T]): T = {
val oldChildren = planModel.childrenOf(node)
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
index b5e9c9891..9591fbb22 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.ras.property.PropertySet
trait RasGroup[T <: AnyRef] {
def id(): Int
def clusterKey(): RasClusterKey
- def propSet(): PropertySet[T]
+ def constraintSet(): PropertySet[T]
def self(): T
def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]]
}
@@ -40,17 +40,17 @@ object RasGroup {
ras: Ras[T],
clusterKey: RasClusterKey,
override val id: Int,
- override val propSet: PropertySet[T])
+ override val constraintSet: PropertySet[T])
extends RasGroup[T] {
- private val groupLeaf: T = ras.planModel.newGroupLeaf(id, clusterKey.metadata, propSet)
+ private val groupLeaf: T = ras.planModel.newGroupLeaf(id, clusterKey.metadata, constraintSet)
override def clusterKey(): RasClusterKey = clusterKey
override def self(): T = groupLeaf
override def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]] = {
- store.getCluster(clusterKey).nodes().filter(n => n.propSet().satisfies(propSet))
+ store.getCluster(clusterKey).nodes().filter(n => n.propSet().satisfies(constraintSet))
}
override def toString(): String = {
- s"RasGroup(id=$id, clusterKey=$clusterKey, propSet=$propSet))"
+ s"RasGroup(id=$id, clusterKey=$clusterKey,
constraintSet=$constraintSet))"
}
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 710a4e682..8c9b52605 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -95,7 +95,11 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] {
object GroupNode {
def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = {
- new GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
+ val self = group.self()
+ // Re-derive property set of group leaf. User should define an appropriate conversion
+ // from group constraints to its output properties in property model or plan model.
+ val propSet = ras.propSetOf(self)
+ new GroupNodeImpl[T](ras, self, propSet, group.id())
}
private class GroupNodeImpl[T <: AnyRef](
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index 3db649b64..c4d3e4881 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -130,7 +130,7 @@ object ExhaustivePlanner {
private def applyEnforcerRules(): Unit = {
allGroups.foreach {
group =>
- val constraintSet = group.propSet()
+ val constraintSet = group.constraintSet()
val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
if (enforcerRules.nonEmpty) {
val shapes = enforcerRules.map(_.shape())
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index 6b4082c7e..3d94a9996 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -64,7 +64,7 @@ object RuleApplier {
equiv =>
closure
.openFor(cKey)
- .memorize(equiv, ras.propertySetFactory().get(equiv))
+ .memorize(equiv, ras.anyPropSet())
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index b420d8c29..d7d14cf3a 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -148,7 +148,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
}
private def describeGroupVerbose(group: RasGroup[T]): String = {
- s"[Group ${group.id()}: ${group.propSet().getMap.values.toIndexedSeq}]"
+ s"[Group ${group.id()}:
${group.constraintSet().getMap.values.toIndexedSeq}]"
}
private def describeNode(
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
index f1c319873..60ec2eedd 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
@@ -411,9 +411,12 @@ object OperationSuite {
equalsCount += 1
delegated.equals(one, other)
}
- override def newGroupLeaf(groupId: Int, metadata: Metadata, propSet: PropertySet[T]): T = {
+ override def newGroupLeaf(
+ groupId: Int,
+ metadata: Metadata,
+ constraintSet: PropertySet[T]): T = {
newGroupLeafCount += 1
- delegated.newGroupLeaf(groupId, metadata, propSet)
+ delegated.newGroupLeaf(groupId, metadata, constraintSet)
}
override def isGroupLeaf(node: T): Boolean = {
isGroupLeafCount += 1
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index aed032226..eb4babe06 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -72,7 +72,7 @@ abstract class PropertySuite extends AnyFunSuite {
memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1, TypedLeaf(TypeB, 1)))))
val state = memo.newState()
assert(state.allClusters().size == 4)
- assert(state.getGroupCount() == 8)
+ assert(state.getGroupCount() == 4)
}
test(s"Get property") {
@@ -573,7 +573,7 @@ object PropertySuite {
override def any(): DummyProperty = DummyProperty(Int.MinValue)
override def getProperty(plan: TestNode): DummyProperty = {
plan match {
- case Group(_, _, _) => throw new IllegalStateException()
+ case g: Group => g.constraintSet.get(this)
case PUnary(_, prop, _) => prop
case PLeaf(_, prop) => prop
case PBinary(_, prop, _, _) => prop
@@ -645,7 +645,7 @@ object PropertySuite {
case class PassNodeType(override val selfCost: Long, child: TestNode)
extends TypedNode {
override def nodeType: NodeType = child match {
case n: TypedNode => n.nodeType
- case g: Group => g.propSet.get(NodeTypeDef)
+ case g: Group => g.constraintSet.get(NodeTypeDef)
case _ => throw new IllegalStateException()
}
@@ -669,7 +669,7 @@ object PropertySuite {
override def shift(node: TestNode): Iterable[TestNode] = {
node match {
case group: Group =>
- val groupType = group.propSet.get(NodeTypeDef)
+ val groupType = group.constraintSet.get(NodeTypeDef)
if (groupType.satisfies(reqType)) {
List(group)
} else {
@@ -710,6 +710,7 @@ object PropertySuite {
object NodeTypeDef extends PropertyDef[TestNode, NodeType] {
override def getProperty(plan: TestNode): NodeType = plan match {
+ case g: Group => g.constraintSet.get(this)
case typed: TypedNode => typed.nodeType
case _ => throw new IllegalStateException()
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
index b5455d6af..65c4d5a07 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
@@ -49,7 +49,7 @@ object RasSuiteBase {
def withNewChildren(children: Seq[TestNode]): TestNode = this
}
- case class Group(id: Int, meta: Metadata, propSet: PropertySet[TestNode]) extends LeafLike {
+ case class Group(id: Int, meta: Metadata, constraintSet: PropertySet[TestNode]) extends LeafLike {
override def selfCost(): Long = Long.MaxValue
override def makeCopy(): LeafLike = copy()
}
@@ -113,8 +113,8 @@ object RasSuiteBase {
override def newGroupLeaf(
groupId: Int,
meta: Metadata,
- propSet: PropertySet[TestNode]): TestNode =
- Group(groupId, meta, propSet)
+ constraintSet: PropertySet[TestNode]): TestNode =
+ Group(groupId, meta, constraintSet)
override def getGroupId(node: TestNode): Int = node match {
case ngl: Group => ngl.id
@@ -163,7 +163,7 @@ object RasSuiteBase {
implicit class MemoLikeImplicits[T <: AnyRef](val memo: MemoLike[T]) {
def memorize(ras: Ras[T], node: T): RasGroup[T] = {
- memo.memorize(node, ras.propSetOf(node))
+ memo.memorize(node, ras.anyPropSet())
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index 7bb713afe..37d66e2bd 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -121,7 +121,7 @@ object MockMemoState {
class MockMutableGroup[T <: AnyRef] private (
override val id: Int,
override val clusterKey: RasClusterKey,
- override val propSet: PropertySet[T],
+ override val constraintSet: PropertySet[T],
override val self: T)
extends RasGroup[T] {
private val nodes: mutable.ArrayBuffer[CanonicalNode[T]] = mutable.ArrayBuffer()
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
index cd8050e5f..bf267a4b6 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
@@ -27,7 +27,7 @@ object MockRasPath {
def mock[T <: AnyRef](ras: Ras[T], node: T, keys: PathKeySet): RasPath[T] = {
val memo = Memo(ras)
- val g = memo.memorize(node, ras.propSetOf(node))
+ val g = memo.memorize(node, ras.anyPropSet())
val state = memo.newState()
val groupSupplier = state.asGroupSupplier()
assert(g.nodes(state).size == 1)
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index 2aefc54e9..e930e4da2 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -251,6 +251,7 @@ object DistributedSuite {
private object DistributionDef extends PropertyDef[TestNode, Distribution] {
override def getProperty(plan: TestNode): Distribution = plan match {
+ case g: Group => g.constraintSet.get(this)
case d: DNode => d.getDistribution()
case _ =>
throw new UnsupportedOperationException()
@@ -308,6 +309,7 @@ object DistributedSuite {
// FIXME: Handle non-ordering as well as non-distribution
private object OrderingDef extends PropertyDef[TestNode, Ordering] {
override def getProperty(plan: TestNode): Ordering = plan match {
+ case g: Group => g.constraintSet.get(this)
case d: DNode => d.getOrdering()
case _ => throw new UnsupportedOperationException()
}
@@ -383,7 +385,7 @@ object DistributedSuite {
with UnaryLike {
override def getDistribution(): Distribution = {
val childDistribution = child match {
- case g: Group => g.propSet.get(DistributionDef)
+ case g: Group => g.constraintSet.get(DistributionDef)
case other => DistributionDef.getProperty(other)
}
if (childDistribution == NoneDistribution) {
@@ -415,7 +417,7 @@ object DistributedSuite {
extends DNode
with UnaryLike {
override def getDistribution(): Distribution = child match {
- case g: Group => g.propSet.get(DistributionDef)
+ case g: Group => g.constraintSet.get(DistributionDef)
case other => DistributionDef.getProperty(other)
}
@@ -433,7 +435,7 @@ object DistributedSuite {
with UnaryLike {
override def getDistribution(): Distribution = {
val childDistribution = child match {
- case g: Group => g.propSet.get(DistributionDef)
+ case g: Group => g.constraintSet.get(DistributionDef)
case other => DistributionDef.getProperty(other)
}
if (childDistribution == NoneDistribution) {
@@ -463,12 +465,12 @@ object DistributedSuite {
  case class DProject(override val child: TestNode) extends DNode with UnaryLike {
override def getDistribution(): Distribution = child match {
- case g: Group => g.propSet.get(DistributionDef)
+ case g: Group => g.constraintSet.get(DistributionDef)
case other => DistributionDef.getProperty(other)
}
override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
override def getOrdering(): Ordering = child match {
- case g: Group => g.propSet.get(OrderingDef)
+ case g: Group => g.constraintSet.get(OrderingDef)
case other => OrderingDef.getProperty(other)
}
override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List(req)
@@ -482,7 +484,7 @@ object DistributedSuite {
with UnaryLike {
override def getDistribution(): Distribution = {
val childDistribution = child match {
- case g: Group => g.propSet.get(DistributionDef)
+ case g: Group => g.constraintSet.get(DistributionDef)
case other => DistributionDef.getProperty(other)
}
if (childDistribution == NoneDistribution) {
@@ -501,13 +503,13 @@ object DistributedSuite {
  case class DSort(keys: Seq[String], override val child: TestNode) extends DNode with UnaryLike {
override def getDistribution(): Distribution = child match {
- case g: Group => g.propSet.get(DistributionDef)
+ case g: Group => g.constraintSet.get(DistributionDef)
case other => DistributionDef.getProperty(other)
}
override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
override def getOrdering(): Ordering = {
val childOrdering = child match {
- case g: Group => g.propSet.get(OrderingDef)
+ case g: Group => g.constraintSet.get(OrderingDef)
case other => OrderingDef.getProperty(other)
}
if (childOrdering.satisfies(SimpleOrdering(keys))) {