This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new d6b298221 [VL] RAS: Reuse same code path with heuristic planner for convention enforcement (#5824)
d6b298221 is described below

commit d6b298221f1360626e52862985f78abc7436183d
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed May 22 17:55:00 2024 +0800

    [VL] RAS: Reuse same code path with heuristic planner for convention enforcement (#5824)
---
 .../backendsapi/velox/VeloxSparkPlanExecApi.scala  |   3 +-
 .../org/apache/gluten/planner/VeloxRasSuite.scala  |  15 ++-
 .../gluten/execution/ColumnarToRowExecBase.scala   |  10 +-
 .../org/apache/gluten/extension/GlutenPlan.scala   |   3 +-
 .../columnar/enumerated/EnumeratedTransform.scala  |  13 +-
 .../extension/columnar/transition/Convention.scala |   4 +
 .../columnar/transition/ConventionFunc.scala       | 115 +++++++++++++---
 .../columnar/transition/ConventionReq.scala        |  17 ++-
 .../extension/columnar/transition/Transition.scala |  30 ++++-
 .../columnar/transition/Transitions.scala          |  48 +------
 .../gluten/planner/plan/GlutenPlanModel.scala      |  70 +++++++---
 .../org/apache/gluten/planner/property/Conv.scala  | 106 +++++++++++++++
 .../gluten/planner/property/Convention.scala       | 147 ---------------------
 .../planner/property/GlutenPropertyModel.scala     |   6 +-
 .../org/apache/spark/util/SparkTaskUtil.scala      |  21 ++-
 .../apache/gluten/columnarbatch/ArrowBatch.scala   |   7 +-
 .../scala/org/apache/gluten/ras/PlanModel.scala    |   2 +-
 .../src/main/scala/org/apache/gluten/ras/Ras.scala |  22 ++-
 .../scala/org/apache/gluten/ras/RasGroup.scala     |  10 +-
 .../main/scala/org/apache/gluten/ras/RasNode.scala |   6 +-
 .../gluten/ras/exaustive/ExhaustivePlanner.scala   |   2 +-
 .../org/apache/gluten/ras/rule/RuleApplier.scala   |   2 +-
 .../apache/gluten/ras/vis/GraphvizVisualizer.scala |   2 +-
 .../org/apache/gluten/ras/OperationSuite.scala     |   7 +-
 .../org/apache/gluten/ras/PropertySuite.scala      |   9 +-
 .../scala/org/apache/gluten/ras/RasSuiteBase.scala |   8 +-
 .../org/apache/gluten/ras/mock/MockMemoState.scala |   2 +-
 .../org/apache/gluten/ras/mock/MockRasPath.scala   |   2 +-
 .../gluten/ras/specific/DistributedSuite.scala     |  18 +--
 29 files changed, 401 insertions(+), 306 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index 2d37b1185..322116582 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -81,7 +81,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
    */
   override def batchTypeFunc(): BatchOverride = {
     case i: InMemoryTableScanExec
-        if i.relation.cacheBuilder.serializer.isInstanceOf[ColumnarCachedBatchSerializer] =>
+        if i.supportsColumnar && i.relation.cacheBuilder.serializer
+          .isInstanceOf[ColumnarCachedBatchSerializer] =>
       VeloxBatch
   }
 
diff --git a/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
index 4690ef516..ae2cea0ba 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/planner/VeloxRasSuite.scala
@@ -16,7 +16,8 @@
  */
 package org.apache.gluten.planner
 
-import org.apache.gluten.planner.property.Conventions
+import org.apache.gluten.extension.columnar.transition.ConventionReq
+import org.apache.gluten.planner.property.Conv
 import org.apache.gluten.ras.Best.BestNotFoundException
 import org.apache.gluten.ras.Ras
 import org.apache.gluten.ras.RasSuiteBase._
@@ -44,7 +45,7 @@ class VeloxRasSuite extends SharedSparkSession {
   test("C2R, R2C - explicitly requires any properties") {
     val in = RowUnary(RowLeaf(TRIVIAL_SCHEMA))
     val planner =
-      newRas().newPlanner(in, PropertySet(List(Conventions.ANY)))
+      newRas().newPlanner(in, PropertySet(List(Conv.any)))
     val out = planner.plan()
     assert(out == RowUnary(RowLeaf(TRIVIAL_SCHEMA)))
   }
@@ -52,7 +53,7 @@ class VeloxRasSuite extends SharedSparkSession {
   test("C2R, R2C - requires columnar output") {
     val in = RowUnary(RowLeaf(TRIVIAL_SCHEMA))
     val planner =
-      newRas().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR)))
+      newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.vanillaBatch))))
     val out = planner.plan()
     assert(out == RowToColumnarExec(RowUnary(RowLeaf(TRIVIAL_SCHEMA))))
   }
@@ -63,7 +64,7 @@ class VeloxRasSuite extends SharedSparkSession {
        RowUnary(
          RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf(TRIVIAL_SCHEMA))))))))
     val planner =
-      newRas().newPlanner(in, PropertySet(List(Conventions.ROW_BASED)))
+      newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.row))))
     val out = planner.plan()
     assert(
       out == ColumnarToRowExec(
@@ -91,7 +92,7 @@ class VeloxRasSuite extends SharedSparkSession {
          RowUnary(ColumnarUnary(RowUnary(RowUnary(ColumnarUnary(RowLeaf(TRIVIAL_SCHEMA))))))))
     val planner =
       newRas(List(ConvertRowUnaryToColumnar))
-        .newPlanner(in, PropertySet(List(Conventions.ROW_BASED)))
+        .newPlanner(in, PropertySet(List(Conv.req(ConventionReq.row))))
     val out = planner.plan()
     assert(out == ColumnarToRowExec(ColumnarUnary(ColumnarUnary(ColumnarUnary(ColumnarUnary(
       ColumnarUnary(ColumnarUnary(ColumnarUnary(RowToColumnarExec(RowLeaf(TRIVIAL_SCHEMA)))))))))))
@@ -104,7 +105,7 @@ class VeloxRasSuite extends SharedSparkSession {
     val in = RowUnary(RowLeaf(EMPTY_SCHEMA))
 
     val planner =
-      newRas().newPlanner(in, PropertySet(List(Conventions.ANY)))
+      newRas().newPlanner(in, PropertySet(List(Conv.any)))
     val out = planner.plan()
     assert(out == RowUnary(RowLeaf(EMPTY_SCHEMA)))
 
@@ -112,7 +113,7 @@ class VeloxRasSuite extends SharedSparkSession {
      // Could not optimize to columnar output since R2C transitions for empty schema node
       // is not allowed.
       val planner2 =
-        newRas().newPlanner(in, PropertySet(List(Conventions.VANILLA_COLUMNAR)))
+        newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.vanillaBatch))))
       planner2.plan()
     }
   }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
index 6d3fa2dac..fd86106bf 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/ColumnarToRowExecBase.scala
@@ -18,6 +18,8 @@ package org.apache.gluten.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.transition.ConventionReq
+import org.apache.gluten.extension.columnar.transition.ConventionReq.KnownChildrenConventions
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -28,7 +30,8 @@ import org.apache.spark.sql.execution.{ColumnarToRowTransition, SparkPlan}
 
 abstract class ColumnarToRowExecBase(child: SparkPlan)
   extends ColumnarToRowTransition
-  with GlutenPlan {
+  with GlutenPlan
+  with KnownChildrenConventions {
 
  // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
   @transient override lazy val metrics =
@@ -50,4 +53,9 @@ abstract class ColumnarToRowExecBase(child: SparkPlan)
   override def doExecute(): RDD[InternalRow] = {
     doExecuteInternal()
   }
+
+  override def requiredChildrenConventions(): Seq[ConventionReq] = {
+    List(ConventionReq.backendBatch)
+  }
+
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
index 033e44b8c..8f1004be4 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/GlutenPlan.scala
@@ -88,8 +88,7 @@ trait GlutenPlan extends SparkPlan with Convention.KnownBatchType with LogLevelU
 
   final override def batchType(): Convention.BatchType = {
     if (!supportsColumnar) {
-      throw new UnsupportedOperationException(
-        s"Node $nodeName doesn't support columnar-batch processing")
+      return Convention.BatchType.None
     }
     val batchType = batchType0()
     assert(batchType != Convention.BatchType.None)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index dc34bc1af..50f0dce13 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -17,8 +17,9 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin, OffloadOthers, OffloadSingleNode}
+import org.apache.gluten.extension.columnar.transition.ConventionReq
 import org.apache.gluten.planner.GlutenOptimization
-import org.apache.gluten.planner.property.Conventions
+import org.apache.gluten.planner.property.Conv
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.utils.LogLevelUtil
 
@@ -48,9 +49,13 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
 
   private val optimization = GlutenOptimization(rules ++ offloadRules)
 
-  private val reqConvention = Conventions.ANY
-  private val altConventions =
-    Seq(Conventions.GLUTEN_COLUMNAR, Conventions.ROW_BASED)
+  private val reqConvention = Conv.any
+
+  private val altConventions = {
+    val rowBased: Conv = Conv.req(ConventionReq.row)
+    val backendBatchBased: Conv = Conv.req(ConventionReq.backendBatch)
+    Seq(rowBased, backendBatchBased)
+  }
 
   override def apply(plan: SparkPlan): SparkPlan = {
     val constraintSet = PropertySet(List(reqConvention))
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
index 2774497d9..034b45851 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Convention.scala
@@ -110,4 +110,8 @@ object Convention {
   trait KnownBatchType {
     def batchType(): BatchType
   }
+
+  trait KnownRowType {
+    def rowType(): RowType
+  }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
index 28bd1d12c..453df5d88 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionFunc.scala
@@ -17,20 +17,22 @@
 package org.apache.gluten.extension.columnar.transition
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.extension.columnar.transition.Convention.{BatchType, RowType}
+import org.apache.gluten.extension.columnar.transition.ConventionReq.KnownChildrenConventions
 import org.apache.gluten.sql.shims.SparkShimLoader
 
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, UnionExec}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
+import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 
-/** ConventionFunc is a utility to derive [[Convention]] from a query plan. */
+/** ConventionFunc is a utility to derive [[Convention]] or [[ConventionReq]] from a query plan. */
 trait ConventionFunc {
   def conventionOf(plan: SparkPlan): Convention
+  def conventionReqOf(plan: SparkPlan): ConventionReq
 }
 
 object ConventionFunc {
-  type BatchOverride = PartialFunction[SparkPlan, BatchType]
+  type BatchOverride = PartialFunction[SparkPlan, Convention.BatchType]
 
   // For testing, to make things work without a backend loaded.
   private var ignoreBackend: Boolean = false
@@ -47,18 +49,22 @@ object ConventionFunc {
   }
 
   def create(): ConventionFunc = {
+    val batchOverride = newOverride()
+    new BuiltinFunc(batchOverride)
+  }
+
+  private def newOverride(): BatchOverride = {
     synchronized {
       if (ignoreBackend) {
         // For testing
-        return new BuiltinFunc(PartialFunction.empty)
+        return PartialFunction.empty
       }
     }
-    val batchOverride = BackendsApiManager.getSparkPlanExecApiInstance.batchTypeFunc()
-    new BuiltinFunc(batchOverride)
+    BackendsApiManager.getSparkPlanExecApiInstance.batchTypeFunc()
   }
 
   private class BuiltinFunc(o: BatchOverride) extends ConventionFunc {
-
+    import BuiltinFunc._
     override def conventionOf(plan: SparkPlan): Convention = {
       val conv = conventionOf0(plan)
       conv
@@ -82,7 +88,7 @@ object ConventionFunc {
          // See org.apache.gluten.extension.columnar.transition.InsertTransitions.apply
           BackendsApiManager.getSparkPlanExecApiInstance.batchType
         } else {
-          BatchType.None
+          Convention.BatchType.None
         }
         val conv = Convention.of(rowType, batchType)
         conv
@@ -91,25 +97,94 @@ object ConventionFunc {
         conv
     }
 
-    private def rowTypeOf(plan: SparkPlan): RowType = {
-      if (!SparkShimLoader.getSparkShims.supportsRowBased(plan)) {
-        return RowType.None
+    private def rowTypeOf(plan: SparkPlan): Convention.RowType = {
+      val out = plan match {
+        case k: Convention.KnownRowType =>
+          k.rowType()
+        case _ if !SparkShimLoader.getSparkShims.supportsRowBased(plan) =>
+          Convention.RowType.None
+        case _ =>
+          Convention.RowType.VanillaRow
       }
-      RowType.VanillaRow
+      if (out != Convention.RowType.None) {
+        assert(SparkShimLoader.getSparkShims.supportsRowBased(plan))
+      }
+      out
     }
 
-    private def batchTypeOf(plan: SparkPlan): BatchType = {
-      if (!plan.supportsColumnar) {
-        return BatchType.None
-      }
-      o.applyOrElse(
+    private def batchTypeOf(plan: SparkPlan): Convention.BatchType = {
+      val out = o.applyOrElse(
         plan,
         (p: SparkPlan) =>
           p match {
-            case g: Convention.KnownBatchType => g.batchType()
-            case _ => BatchType.VanillaBatch
+            case k: Convention.KnownBatchType =>
+              k.batchType()
+            case _ if !plan.supportsColumnar =>
+              Convention.BatchType.None
+            case _ =>
+              Convention.BatchType.VanillaBatch
           }
       )
+      if (out != Convention.BatchType.None) {
+        assert(plan.supportsColumnar)
+      }
+      out
+    }
+
+    override def conventionReqOf(plan: SparkPlan): ConventionReq = {
+      val out = conventionReqOf0(plan)
+      out
+    }
+
+    private def conventionReqOf0(plan: SparkPlan): ConventionReq = plan match {
+      case k: KnownChildrenConventions =>
+        val reqs = k.requiredChildrenConventions().distinct
+        // This can be a temporary restriction.
+        assert(
+          reqs.size == 1,
+          "KnownChildrenConventions#requiredChildrenConventions should output the same element" +
+            " for all children")
+        reqs.head
+      case RowToColumnarLike(_) =>
+        ConventionReq.of(
+          ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
+          ConventionReq.BatchType.Any)
+      case ColumnarToRowExec(_) =>
+        ConventionReq.of(
+          ConventionReq.RowType.Any,
+          ConventionReq.BatchType.Is(Convention.BatchType.VanillaBatch))
+      case write: DataWritingCommandExec if SparkShimLoader.getSparkShims.isPlannedV1Write(write) =>
+        // To align with ApplyColumnarRulesAndInsertTransitions#insertTransitions
+        ConventionReq.any
+      case u: UnionExec =>
+        // We force vanilla union to output row data to get best compatibility with vanilla Spark.
+        // As a result it's a common practice to rewrite it with GlutenPlan for offloading.
+        ConventionReq.of(
+          ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
+          ConventionReq.BatchType.Any)
+      case other =>
+        // In the normal case, children's convention should follow the parent node's convention.
+        // Note, we don't have to consider C2R / R2C here since they are already removed by
+        // RemoveTransitions.
+        val thisConv = conventionOf0(other)
+        thisConv.asReq()
+    }
+  }
+
+  private object BuiltinFunc {
+    implicit private class ConventionOps(conv: Convention) {
+      def asReq(): ConventionReq = {
+        val rowTypeReq = conv.rowType match {
+          case Convention.RowType.None => ConventionReq.RowType.Any
+          case r => ConventionReq.RowType.Is(r)
+        }
+
+        val batchTypeReq = conv.batchType match {
+          case Convention.BatchType.None => ConventionReq.BatchType.Any
+          case b => ConventionReq.BatchType.Is(b)
+        }
+        ConventionReq.of(rowTypeReq, batchTypeReq)
+      }
     }
   }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
index aac2084a7..65422b380 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/ConventionReq.scala
@@ -16,6 +16,10 @@
  */
 package org.apache.gluten.extension.columnar.transition
 
+import org.apache.gluten.backendsapi.BackendsApiManager
+
+import org.apache.spark.sql.execution.SparkPlan
+
 /**
  * ConventionReq describes the requirement for [[Convention]]. This is mostly used in determining
  * the acceptable conventions for the children of a parent plan node.
@@ -50,5 +54,16 @@ object ConventionReq {
   ) extends ConventionReq
 
   val any: ConventionReq = Impl(RowType.Any, BatchType.Any)
-  def of(rowType: RowType, batchType: BatchType): ConventionReq = new Impl(rowType, batchType)
+  val row: ConventionReq = Impl(RowType.Is(Convention.RowType.VanillaRow), BatchType.Any)
+  val vanillaBatch: ConventionReq =
+    Impl(RowType.Any, BatchType.Is(Convention.BatchType.VanillaBatch))
+  lazy val backendBatch: ConventionReq =
+    Impl(RowType.Any, BatchType.Is(BackendsApiManager.getSparkPlanExecApiInstance.batchType))
+
+  def get(plan: SparkPlan): ConventionReq = ConventionFunc.create().conventionReqOf(plan)
+  def of(rowType: RowType, batchType: BatchType): ConventionReq = Impl(rowType, batchType)
+
+  trait KnownChildrenConventions {
+    def requiredChildrenConventions(): Seq[ConventionReq]
+  }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 9b745f94d..73a126f8d 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -29,7 +29,21 @@ import scala.collection.mutable
  * [[org.apache.gluten.extension.columnar.transition.Convention.BatchType]]'s definition.
  */
 trait Transition {
-  def apply(plan: SparkPlan): SparkPlan
+  final def apply(plan: SparkPlan): SparkPlan = {
+    val out = apply0(plan)
+    if (out.fastEquals(plan)) {
+      assert(
+        this == Transition.empty,
+        "TransitionDef.empty / Transition.empty should be used when defining an empty transition.")
+    }
+    out
+  }
+
+  final def isEmpty: Boolean = {
+    this == Transition.empty
+  }
+
+  protected def apply0(plan: SparkPlan): SparkPlan
 }
 
 trait TransitionDef {
@@ -53,12 +67,15 @@ object Transition {
   }
 
   private class ChainedTransition(first: Transition, second: Transition) extends Transition {
-    override def apply(plan: SparkPlan): SparkPlan = {
+    override def apply0(plan: SparkPlan): SparkPlan = {
       second(first(plan))
     }
   }
 
   private def chain(first: Transition, second: Transition): Transition = {
+    if (first.isEmpty && second.isEmpty) {
+      return Transition.empty
+    }
     new ChainedTransition(first, second)
   }
 
@@ -72,6 +89,15 @@ object Transition {
       }
     }
 
+    final def satisfies(conv: Convention, req: ConventionReq): Boolean = {
+      val none = new Transition {
+        override protected def apply0(plan: SparkPlan): SparkPlan =
+          throw new UnsupportedOperationException()
+      }
+      val transition = findTransition(conv, req)(none)
+      transition.isEmpty
+    }
+
     protected def findTransition(from: Convention, to: ConventionReq)(
         orElse: => Transition): Transition
     private[transition] def update(): MutableFactory
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
index e0758cff7..d02aadd49 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transitions.scala
@@ -17,16 +17,13 @@
 package org.apache.gluten.extension.columnar.transition
 
 import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
-import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.SparkPlan
 
 import scala.annotation.tailrec
 
 case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
-  import InsertTransitions._
   private val convFunc = ConventionFunc.create()
 
   override def apply(plan: SparkPlan): SparkPlan = {
@@ -47,7 +44,7 @@ case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
     if (node.children.isEmpty) {
       return node
     }
-    val convReq = childrenConvReqOf(node)
+    val convReq = convFunc.conventionReqOf(node)
     val newChildren = node.children.map {
       child =>
         val from = convFunc.conventionOf(child)
@@ -64,47 +61,6 @@ case class InsertTransitions(outputsColumnar: Boolean) extends Rule[SparkPlan] {
     }
     node.withNewChildren(newChildren)
   }
-
-  private def childrenConvReqOf(node: SparkPlan): ConventionReq = node match {
-    // TODO: Consider C2C transitions as well when we have some.
-    case ColumnarToRowLike(_) | RowToColumnarLike(_) =>
-      // C2R / R2C here since they are already removed by
-      // RemoveTransitions.
-      // It's current rule's mission to add C2Rs / R2Cs on demand.
-      throw new IllegalStateException("Unreachable code")
-    case write: DataWritingCommandExec if SparkShimLoader.getSparkShims.isPlannedV1Write(write) =>
-      // To align with ApplyColumnarRulesAndInsertTransitions#insertTransitions
-      ConventionReq.any
-    case u: UnionExec =>
-      // We force vanilla union to output row data to get best compatibility with vanilla Spark.
-      // As a result it's a common practice to rewrite it with GlutenPlan for offloading.
-      ConventionReq.of(
-        ConventionReq.RowType.Is(Convention.RowType.VanillaRow),
-        ConventionReq.BatchType.Any)
-    case other =>
-      // In the normal case, children's convention should follow parent node's convention.
-      // Note, we don't have consider C2R / R2C here since they are already removed by
-      // RemoveTransitions.
-      val thisConv = convFunc.conventionOf(other)
-      thisConv.asReq()
-  }
-}
-
-object InsertTransitions {
-  implicit private class ConventionOps(conv: Convention) {
-    def asReq(): ConventionReq = {
-      val rowTypeReq = conv.rowType match {
-        case Convention.RowType.None => ConventionReq.RowType.Any
-        case r => ConventionReq.RowType.Is(r)
-      }
-
-      val batchTypeReq = conv.batchType match {
-        case Convention.BatchType.None => ConventionReq.BatchType.Any
-        case b => ConventionReq.BatchType.Is(b)
-      }
-      ConventionReq.of(rowTypeReq, batchTypeReq)
-    }
-  }
 }
 
 object RemoveTransitions extends Rule[SparkPlan] {
diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
index f0ae4286f..7417d9a5d 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/plan/GlutenPlanModel.scala
@@ -16,15 +16,19 @@
  */
 package org.apache.gluten.planner.plan
 
+import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
+import org.apache.gluten.extension.columnar.transition.Convention.{KnownBatchType, KnownRowType}
 import org.apache.gluten.planner.metadata.GlutenMetadata
-import org.apache.gluten.planner.property.{ConventionDef, Conventions}
+import org.apache.gluten.planner.property.{Conv, ConvDef}
 import org.apache.gluten.ras.{Metadata, PlanModel}
 import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode, SparkPlan}
+import org.apache.spark.util.{SparkTaskUtil, TaskResources}
 
 import java.util.Objects
 
@@ -36,25 +40,61 @@ object GlutenPlanModel {
   case class GroupLeafExec(
       groupId: Int,
       metadata: GlutenMetadata,
-      propertySet: PropertySet[SparkPlan])
-    extends LeafExecNode {
+      constraintSet: PropertySet[SparkPlan])
+    extends LeafExecNode
+    with KnownBatchType
+    with KnownRowType {
+    private val req: Conv.Req = constraintSet.get(ConvDef).asInstanceOf[Conv.Req]
+
     override protected def doExecute(): RDD[InternalRow] = throw new IllegalStateException()
     override def output: Seq[Attribute] = metadata.schema().output
-    override def supportsColumnar: Boolean =
-      propertySet.get(ConventionDef) match {
-        case Conventions.ROW_BASED => false
-        case Conventions.VANILLA_COLUMNAR => true
-        case Conventions.GLUTEN_COLUMNAR => true
-        case Conventions.ANY => true
+
+    override def supportsColumnar(): Boolean = {
+      batchType != Convention.BatchType.None
+    }
+
+    override val batchType: Convention.BatchType = {
+      val out = req.req.requiredBatchType match {
+        case ConventionReq.BatchType.Any => Convention.BatchType.None
+        case ConventionReq.BatchType.Is(b) => b
+      }
+      out
+    }
+
+    override val rowType: Convention.RowType = {
+      val out = req.req.requiredRowType match {
+        case ConventionReq.RowType.Any => Convention.RowType.None
+        case ConventionReq.RowType.Is(r) => r
       }
+      out
+    }
   }
 
   private object PlanModelImpl extends PlanModel[SparkPlan] {
+    private val fakeTc = SparkShimLoader.getSparkShims.createTestTaskContext()
+    private def fakeTc[T](body: => T): T = {
+      assert(!TaskResources.inSparkTask())
+      SparkTaskUtil.setTaskContext(fakeTc)
+      try {
+        body
+      } finally {
+        SparkTaskUtil.unsetTaskContext()
+      }
+    }
+
     override def childrenOf(node: SparkPlan): Seq[SparkPlan] = node.children
 
-    override def withNewChildren(node: SparkPlan, children: Seq[SparkPlan]): SparkPlan = {
-      node.withNewChildren(children)
-    }
+    override def withNewChildren(node: SparkPlan, children: Seq[SparkPlan]): SparkPlan =
+      node match {
+        case c2r: ColumnarToRowExec =>
+          // Workaround: To bypass the assertion in ColumnarToRowExec's code if child is
+          // a group leaf.
+          fakeTc {
+            c2r.withNewChildren(children)
+          }
+        case other =>
+          other.withNewChildren(children)
+      }
 
     override def hashCode(node: SparkPlan): Int = Objects.hashCode(node)
 
@@ -63,8 +103,8 @@ object GlutenPlanModel {
     override def newGroupLeaf(
         groupId: Int,
         metadata: Metadata,
-        propSet: PropertySet[SparkPlan]): SparkPlan =
-      GroupLeafExec(groupId, metadata.asInstanceOf[GlutenMetadata], propSet)
+        constraintSet: PropertySet[SparkPlan]): SparkPlan =
+      GroupLeafExec(groupId, metadata.asInstanceOf[GlutenMetadata], constraintSet)
 
     override def isGroupLeaf(node: SparkPlan): Boolean = node match {
       case _: GroupLeafExec => true
diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
new file mode 100644
index 000000000..475f62920
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.planner.property
+
+import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq, Transition}
+import org.apache.gluten.ras.{Property, PropertyDef}
+import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+
+import org.apache.spark.sql.execution._
+
+sealed trait Conv extends Property[SparkPlan] {
+  import Conv._
+  override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]] = {
+    ConvDef
+  }
+
+  override def satisfies(other: Property[SparkPlan]): Boolean = {
+    val req = other.asInstanceOf[Req]
+    if (req.isAny) {
+      return true
+    }
+    val prop = this.asInstanceOf[Prop]
+    val out = Transition.factory.satisfies(prop.prop, req.req)
+    out
+  }
+}
+
+object Conv {
+  val any: Conv = Req(ConventionReq.any)
+
+  def of(conv: Convention): Conv = Prop(conv)
+  def req(req: ConventionReq): Conv = Req(req)
+
+  def get(plan: SparkPlan): Conv = {
+    Conv.of(Convention.get(plan))
+  }
+
+  def findTransition(from: Conv, to: Conv): Transition = {
+    val prop = from.asInstanceOf[Prop]
+    val req = to.asInstanceOf[Req]
+    val out = Transition.factory.findTransition(prop.prop, req.req, new IllegalStateException())
+    out
+  }
+
+  case class Prop(prop: Convention) extends Conv
+  case class Req(req: ConventionReq) extends Conv {
+    def isAny: Boolean = {
+      req.requiredBatchType == ConventionReq.BatchType.Any &&
+      req.requiredRowType == ConventionReq.RowType.Any
+    }
+  }
+}
+
+object ConvDef extends PropertyDef[SparkPlan, Conv] {
+  // TODO: Should the convention-transparent ops (e.g., aqe shuffle read) support
+  //  convention-propagation. Probably need to refactor getChildrenPropertyRequirements.
+  override def getProperty(plan: SparkPlan): Conv = {
+    conventionOf(plan)
+  }
+
+  private def conventionOf(plan: SparkPlan): Conv = {
+    val out = Conv.get(plan)
+    out
+  }
+
+  override def getChildrenConstraints(
+      constraint: Property[SparkPlan],
+      plan: SparkPlan): Seq[Conv] = {
+    val out = List.tabulate(plan.children.size)(_ => Conv.req(ConventionReq.get(plan)))
+    out
+  }
+
+  override def any(): Conv = Conv.any
+}
+
+case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] {
+  override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+    if (node.output.isEmpty) {
+      // Disable transitions for nodes whose output schema is empty.
+      return List.empty
+    }
+    val conv = Conv.get(node)
+    if (conv.satisfies(reqConv)) {
+      return List.empty
+    }
+    val transition = Conv.findTransition(conv, reqConv)
+    val after = transition.apply(node)
+    List(after)
+  }
+
+  override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
+}
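
A usage sketch of the new Conv property (not part of the patch; newRas() is the
VeloxRasSuite helper shown earlier in this email): the RAS planner is seeded
with a Conv requirement, and ConvEnforcerRule satisfies it by applying whatever
transition Transition.factory resolves, the same registry the heuristic
InsertTransitions rule consults.

    val planner = newRas().newPlanner(in, PropertySet(List(Conv.req(ConventionReq.row))))
    val out = planner.plan() // e.g. ColumnarToRowExec gets inserted where required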
diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Convention.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Convention.scala
deleted file mode 100644
index 5fe96ab79..000000000
--- a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Convention.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.gluten.planner.property
-
-import org.apache.gluten.execution.RowToColumnarExecBase
-import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike, Transitions}
-import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
-import org.apache.gluten.ras.{Property, PropertyDef}
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
-import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.utils.PlanUtil
-
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec}
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
-import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-
-sealed trait Convention extends Property[SparkPlan] {
-  override def definition(): PropertyDef[SparkPlan, _ <: Property[SparkPlan]] = {
-    ConventionDef
-  }
-
-  override def satisfies(other: Property[SparkPlan]): Boolean = other match {
-    case Conventions.ANY => true
-    case c: Convention => c == this
-    case _ => throw new IllegalStateException()
-  }
-}
-
-object Conventions {
-  // FIXME: Velox and CH should have different conventions?
-  case object ROW_BASED extends Convention
-  case object VANILLA_COLUMNAR extends Convention
-  case object GLUTEN_COLUMNAR extends Convention
-  case object ANY extends Convention
-}
-
-object ConventionDef extends PropertyDef[SparkPlan, Convention] {
-  // TODO: Should the convention-transparent ops (e.g., aqe shuffle read) support
-  //  convention-propagation. Probably need to refactor getChildrenPropertyRequirements.
-  override def getProperty(plan: SparkPlan): Convention = plan match {
-    case _: GroupLeafExec => throw new IllegalStateException()
-    case other => conventionOf(other)
-  }
-
-  private def conventionOf(plan: SparkPlan): Convention = plan match {
-    case g: GroupLeafExec => g.propertySet.get(ConventionDef)
-    case ColumnarToRowExec(child) => Conventions.ROW_BASED
-    case RowToColumnarExec(child) => Conventions.VANILLA_COLUMNAR
-    case ColumnarToRowLike(child) => Conventions.ROW_BASED
-    case RowToColumnarLike(child) => Conventions.GLUTEN_COLUMNAR
-    case q: QueryStageExec => conventionOf(q.plan)
-    case r: ReusedExchangeExec => conventionOf(r.child)
-    case a: AdaptiveSparkPlanExec => conventionOf(a.executedPlan)
-    case i: InMemoryTableScanExec => getCacheConvention(i)
-    case p if canPropagateConvention(p) =>
-      val childrenProps = p.children.map(conventionOf).distinct
-      assert(childrenProps.size == 1)
-      childrenProps.head
-    case _: GlutenPlan => Conventions.GLUTEN_COLUMNAR
-    case p if p.supportsColumnar => Conventions.VANILLA_COLUMNAR
-    case p if SparkShimLoader.getSparkShims.supportsRowBased(p) => Conventions.ROW_BASED
-    case other => throw new IllegalStateException(s"Unable to get convention of $other")
-  }
-
-  override def getChildrenConstraints(
-      constraint: Property[SparkPlan],
-      plan: SparkPlan): Seq[Convention] = plan match {
-    case ColumnarToRowExec(child) => Seq(Conventions.VANILLA_COLUMNAR)
-    case ColumnarToRowLike(child) => Seq(Conventions.GLUTEN_COLUMNAR)
-    case RowToColumnarLike(child) => Seq(Conventions.ROW_BASED)
-    case p if canPropagateConvention(p) =>
-      p.children.map(_ => constraint.asInstanceOf[Convention])
-    case other =>
-      val conv = conventionOf(other)
-      other.children.map(_ => conv)
-  }
-
-  override def any(): Convention = Conventions.ANY
-
-  private def canPropagateConvention(plan: SparkPlan): Boolean = plan match {
-    case p: AQEShuffleReadExec => true
-    case p: InputAdapter => true
-    case p: WholeStageCodegenExec => true
-    case _ => false
-  }
-
-  private def getCacheConvention(i: InMemoryTableScanExec): Convention = {
-    if (PlanUtil.isGlutenTableCache(i)) {
-      Conventions.GLUTEN_COLUMNAR
-    } else if (i.supportsColumnar) {
-      Conventions.VANILLA_COLUMNAR
-    } else {
-      Conventions.ROW_BASED
-    }
-  }
-}
-
-case class ConventionEnforcerRule(reqConv: Convention) extends RasRule[SparkPlan] {
-  override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-    if (node.output.isEmpty) {
-      // Disable transitions for node that has output with empty schema.
-      return List.empty
-    }
-    val conv = ConventionDef.getProperty(node)
-    if (conv.satisfies(reqConv)) {
-      return List.empty
-    }
-    (conv, reqConv) match {
-      case (Conventions.VANILLA_COLUMNAR, Conventions.ROW_BASED) =>
-        List(ColumnarToRowExec(node))
-      case (Conventions.ROW_BASED, Conventions.VANILLA_COLUMNAR) =>
-        List(RowToColumnarExec(node))
-      case (Conventions.GLUTEN_COLUMNAR, Conventions.ROW_BASED) =>
-        List(Transitions.toRowPlan(node))
-      case (Conventions.ROW_BASED, Conventions.GLUTEN_COLUMNAR) =>
-        val attempt = Transitions.toBackendBatchPlan(node)
-        if (attempt.asInstanceOf[RowToColumnarExecBase].doValidate().isValid) {
-          List(attempt)
-        } else {
-          List.empty
-        }
-      case (Conventions.VANILLA_COLUMNAR, Conventions.GLUTEN_COLUMNAR) =>
-        List(Transitions.toBackendBatchPlan(ColumnarToRowExec(node)))
-      case (Conventions.GLUTEN_COLUMNAR, Conventions.VANILLA_COLUMNAR) =>
-        List(RowToColumnarExec(Transitions.toRowPlan(node)))
-      case _ => List.empty
-    }
-  }
-
-  override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
-}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
index a998c935d..115ab4471 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala
@@ -28,14 +28,14 @@ object GlutenPropertyModel {
 
   private object PropertyModelImpl extends PropertyModel[SparkPlan] {
     override def propertyDefs: Seq[PropertyDef[SparkPlan, _ <: Property[SparkPlan]]] =
-      Seq(ConventionDef)
+      Seq(ConvDef)
 
     override def newEnforcerRuleFactory(
         propertyDef: PropertyDef[SparkPlan, _ <: Property[SparkPlan]])
         : EnforcerRuleFactory[SparkPlan] = (reqProp: Property[SparkPlan]) => {
       propertyDef match {
-        case ConventionDef =>
-          Seq(ConventionEnforcerRule(reqProp.asInstanceOf[Convention]))
+        case ConvDef =>
+          Seq(ConvEnforcerRule(reqProp.asInstanceOf[Conv]))
       }
     }
   }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkTaskUtil.scala
similarity index 63%
copy from gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
copy to gluten-core/src/main/scala/org/apache/spark/util/SparkTaskUtil.scala
index 34924ccbf..92a12b3c6 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkTaskUtil.scala
@@ -14,19 +14,16 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.ras
+package org.apache.spark.util
 
-import org.apache.gluten.ras.property.PropertySet
+import org.apache.spark.TaskContext
 
-trait PlanModel[T <: AnyRef] {
-  // Trivial tree operations.
-  def childrenOf(node: T): Seq[T]
-  def withNewChildren(node: T, children: Seq[T]): T
-  def hashCode(node: T): Int
-  def equals(one: T, other: T): Boolean
+object SparkTaskUtil {
+  def setTaskContext(taskContext: TaskContext): Unit = {
+    TaskContext.setTaskContext(taskContext)
+  }
 
-  // Group operations.
-  def newGroupLeaf(groupId: Int, meta: Metadata, propSet: PropertySet[T]): T
-  def isGroupLeaf(node: T): Boolean
-  def getGroupId(node: T): Int
+  def unsetTaskContext(): Unit = {
+    TaskContext.unset()
+  }
 }
diff --git a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
index 3f40793d9..58a88e1f4 100644
--- a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
+++ b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ArrowBatch.scala
@@ -17,7 +17,8 @@
 
 package org.apache.gluten.columnarbatch
 
-import org.apache.gluten.extension.columnar.transition.Convention
+import org.apache.gluten.extension.columnar.transition.{Convention, TransitionDef}
+import org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch
 
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan}
 
@@ -38,4 +39,8 @@ object ArrowBatch extends Convention.BatchType {
       (plan: SparkPlan) => {
         ColumnarToRowExec(plan)
       })
+
+  // Arrow batch is one-way compatible with vanilla batch since it provides valid
+  // #get<type>(...) implementations.
+  toBatch(VanillaBatch, TransitionDef.empty)
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
index 34924ccbf..bac9d0b64 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PlanModel.scala
@@ -26,7 +26,7 @@ trait PlanModel[T <: AnyRef] {
   def equals(one: T, other: T): Boolean
 
   // Group operations.
-  def newGroupLeaf(groupId: Int, meta: Metadata, propSet: PropertySet[T]): T
+  def newGroupLeaf(groupId: Int, meta: Metadata, constraintSet: PropertySet[T]): T
   def isGroupLeaf(node: T): Boolean
   def getGroupId(node: T): Int
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 804d04d81..f705a2901 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -30,9 +30,7 @@ trait Optimization[T <: AnyRef] {
       plan: T,
       constraintSet: PropertySet[T],
       altConstraintSets: Seq[PropertySet[T]]): RasPlanner[T]
-
-  def propSetOf(plan: T): PropertySet[T]
-
+  def anyPropSet(): PropertySet[T]
   def withNewConfig(confFunc: RasConfig => RasConfig): Optimization[T]
 }
 
@@ -49,7 +47,7 @@ object Optimization {
 
   implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
     def newPlanner(plan: T): RasPlanner[T] = {
-      opt.newPlanner(plan, opt.propSetOf(plan), List.empty)
+      opt.newPlanner(plan, opt.anyPropSet(), List.empty)
     }
     def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T] = {
       opt.newPlanner(plan, constraintSet, List.empty)
@@ -113,15 +111,6 @@ class Ras[T <: AnyRef] private (
       // Node groups don't have user-defined cost, expect exception here.
       metadataModel.metadataOf(dummyGroup)
     }
-    propertyModel.propertyDefs.foreach {
-      propDef =>
-        // Node groups don't have user-defined property, expect exception here.
-        assertThrows(
-          "Group is not allowed to return its property directly to optimizer (optimizer already" +
-            " knew that). It's expected to throw an exception when getting its property but not") {
-          propDef.getProperty(dummyGroup)
-        }
-    }
   }
 
   override def newPlanner(
@@ -131,7 +120,12 @@ class Ras[T <: AnyRef] private (
     RasPlanner(this, altConstraintSets, constraintSet, plan)
   }
 
-  override def propSetOf(plan: T): PropertySet[T] = propertySetFactory().get(plan)
+  override def anyPropSet(): PropertySet[T] = propertySetFactory().any()
+
+  private[ras] def propSetOf(plan: T): PropertySet[T] = {
+    val out = propertySetFactory().get(plan)
+    out
+  }
 
   private[ras] def withNewChildren(node: T, newChildren: Seq[T]): T = {
     val oldChildren = planModel.childrenOf(node)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
index b5e9c9891..9591fbb22 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasGroup.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.ras.property.PropertySet
 trait RasGroup[T <: AnyRef] {
   def id(): Int
   def clusterKey(): RasClusterKey
-  def propSet(): PropertySet[T]
+  def constraintSet(): PropertySet[T]
   def self(): T
   def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]]
 }
@@ -40,17 +40,17 @@ object RasGroup {
       ras: Ras[T],
       clusterKey: RasClusterKey,
       override val id: Int,
-      override val propSet: PropertySet[T])
+      override val constraintSet: PropertySet[T])
     extends RasGroup[T] {
-    private val groupLeaf: T = ras.planModel.newGroupLeaf(id, clusterKey.metadata, propSet)
+    private val groupLeaf: T = ras.planModel.newGroupLeaf(id, clusterKey.metadata, constraintSet)
 
     override def clusterKey(): RasClusterKey = clusterKey
     override def self(): T = groupLeaf
     override def nodes(store: MemoStore[T]): Iterable[CanonicalNode[T]] = {
-      store.getCluster(clusterKey).nodes().filter(n => n.propSet().satisfies(propSet))
+      store.getCluster(clusterKey).nodes().filter(n => n.propSet().satisfies(constraintSet))
     }
     override def toString(): String = {
-      s"RasGroup(id=$id, clusterKey=$clusterKey, propSet=$propSet))"
+      s"RasGroup(id=$id, clusterKey=$clusterKey, constraintSet=$constraintSet))"
     }
   }
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 710a4e682..8c9b52605 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -95,7 +95,11 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] {
 
 object GroupNode {
   def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = {
-    new GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
+    val self = group.self()
+    // Re-derive property set of group leaf. User should define an appropriate conversion
+    // from group constraints to its output properties in property model or plan model.
+    val propSet = ras.propSetOf(self)
+    new GroupNodeImpl[T](ras, self, propSet, group.id())
   }
 
   private class GroupNodeImpl[T <: AnyRef](
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index 3db649b64..c4d3e4881 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -130,7 +130,7 @@ object ExhaustivePlanner {
     private def applyEnforcerRules(): Unit = {
       allGroups.foreach {
         group =>
-          val constraintSet = group.propSet()
+          val constraintSet = group.constraintSet()
           val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
           if (enforcerRules.nonEmpty) {
             val shapes = enforcerRules.map(_.shape())
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index 6b4082c7e..3d94a9996 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -64,7 +64,7 @@ object RuleApplier {
         equiv =>
           closure
             .openFor(cKey)
-            .memorize(equiv, ras.propertySetFactory().get(equiv))
+            .memorize(equiv, ras.anyPropSet())
       }
     }
 
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index b420d8c29..d7d14cf3a 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -148,7 +148,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
   }
 
   private def describeGroupVerbose(group: RasGroup[T]): String = {
-    s"[Group ${group.id()}: ${group.propSet().getMap.values.toIndexedSeq}]"
+    s"[Group ${group.id()}: ${group.constraintSet().getMap.values.toIndexedSeq}]"
   }
 
   private def describeNode(
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
index f1c319873..60ec2eedd 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala
@@ -411,9 +411,12 @@ object OperationSuite {
       equalsCount += 1
       delegated.equals(one, other)
     }
-    override def newGroupLeaf(groupId: Int, metadata: Metadata, propSet: PropertySet[T]): T = {
+    override def newGroupLeaf(
+        groupId: Int,
+        metadata: Metadata,
+        constraintSet: PropertySet[T]): T = {
       newGroupLeafCount += 1
-      delegated.newGroupLeaf(groupId, metadata, propSet)
+      delegated.newGroupLeaf(groupId, metadata, constraintSet)
     }
     override def isGroupLeaf(node: T): Boolean = {
       isGroupLeafCount += 1
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index aed032226..eb4babe06 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -72,7 +72,7 @@ abstract class PropertySuite extends AnyFunSuite {
     memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1, TypedLeaf(TypeB, 1)))))
     val state = memo.newState()
     assert(state.allClusters().size == 4)
-    assert(state.getGroupCount() == 8)
+    assert(state.getGroupCount() == 4)
   }
 
   test(s"Get property") {
@@ -573,7 +573,7 @@ object PropertySuite {
     override def any(): DummyProperty = DummyProperty(Int.MinValue)
     override def getProperty(plan: TestNode): DummyProperty = {
       plan match {
-        case Group(_, _, _) => throw new IllegalStateException()
+        case g: Group => g.constraintSet.get(this)
         case PUnary(_, prop, _) => prop
         case PLeaf(_, prop) => prop
         case PBinary(_, prop, _, _) => prop
@@ -645,7 +645,7 @@ object PropertySuite {
   case class PassNodeType(override val selfCost: Long, child: TestNode) extends TypedNode {
     override def nodeType: NodeType = child match {
       case n: TypedNode => n.nodeType
-      case g: Group => g.propSet.get(NodeTypeDef)
+      case g: Group => g.constraintSet.get(NodeTypeDef)
       case _ => throw new IllegalStateException()
     }
 
@@ -669,7 +669,7 @@ object PropertySuite {
     override def shift(node: TestNode): Iterable[TestNode] = {
       node match {
         case group: Group =>
-          val groupType = group.propSet.get(NodeTypeDef)
+          val groupType = group.constraintSet.get(NodeTypeDef)
           if (groupType.satisfies(reqType)) {
             List(group)
           } else {
@@ -710,6 +710,7 @@ object PropertySuite {
 
   object NodeTypeDef extends PropertyDef[TestNode, NodeType] {
     override def getProperty(plan: TestNode): NodeType = plan match {
+      case g: Group => g.constraintSet.get(this)
       case typed: TypedNode => typed.nodeType
       case _ => throw new IllegalStateException()
     }
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
index b5455d6af..65c4d5a07 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
@@ -49,7 +49,7 @@ object RasSuiteBase {
     def withNewChildren(children: Seq[TestNode]): TestNode = this
   }
 
-  case class Group(id: Int, meta: Metadata, propSet: PropertySet[TestNode]) extends LeafLike {
+  case class Group(id: Int, meta: Metadata, constraintSet: PropertySet[TestNode]) extends LeafLike {
     override def selfCost(): Long = Long.MaxValue
     override def makeCopy(): LeafLike = copy()
   }
@@ -113,8 +113,8 @@ object RasSuiteBase {
     override def newGroupLeaf(
         groupId: Int,
         meta: Metadata,
-        propSet: PropertySet[TestNode]): TestNode =
-      Group(groupId, meta, propSet)
+        constraintSet: PropertySet[TestNode]): TestNode =
+      Group(groupId, meta, constraintSet)
 
     override def getGroupId(node: TestNode): Int = node match {
       case ngl: Group => ngl.id
@@ -163,7 +163,7 @@ object RasSuiteBase {
 
   implicit class MemoLikeImplicits[T <: AnyRef](val memo: MemoLike[T]) {
     def memorize(ras: Ras[T], node: T): RasGroup[T] = {
-      memo.memorize(node, ras.propSetOf(node))
+      memo.memorize(node, ras.anyPropSet())
     }
   }
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index 7bb713afe..37d66e2bd 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -121,7 +121,7 @@ object MockMemoState {
   class MockMutableGroup[T <: AnyRef] private (
       override val id: Int,
       override val clusterKey: RasClusterKey,
-      override val propSet: PropertySet[T],
+      override val constraintSet: PropertySet[T],
       override val self: T)
     extends RasGroup[T] {
     private val nodes: mutable.ArrayBuffer[CanonicalNode[T]] = mutable.ArrayBuffer()
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
index cd8050e5f..bf267a4b6 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
@@ -27,7 +27,7 @@ object MockRasPath {
 
   def mock[T <: AnyRef](ras: Ras[T], node: T, keys: PathKeySet): RasPath[T] = {
     val memo = Memo(ras)
-    val g = memo.memorize(node, ras.propSetOf(node))
+    val g = memo.memorize(node, ras.anyPropSet())
     val state = memo.newState()
     val groupSupplier = state.asGroupSupplier()
     assert(g.nodes(state).size == 1)
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index 2aefc54e9..e930e4da2 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -251,6 +251,7 @@ object DistributedSuite {
 
   private object DistributionDef extends PropertyDef[TestNode, Distribution] {
     override def getProperty(plan: TestNode): Distribution = plan match {
+      case g: Group => g.constraintSet.get(this)
       case d: DNode => d.getDistribution()
       case _ =>
         throw new UnsupportedOperationException()
@@ -308,6 +309,7 @@ object DistributedSuite {
   // FIXME: Handle non-ordering as well as non-distribution
   private object OrderingDef extends PropertyDef[TestNode, Ordering] {
     override def getProperty(plan: TestNode): Ordering = plan match {
+      case g: Group => g.constraintSet.get(this)
       case d: DNode => d.getOrdering()
       case _ => throw new UnsupportedOperationException()
     }
@@ -383,7 +385,7 @@ object DistributedSuite {
     with UnaryLike {
     override def getDistribution(): Distribution = {
       val childDistribution = child match {
-        case g: Group => g.propSet.get(DistributionDef)
+        case g: Group => g.constraintSet.get(DistributionDef)
         case other => DistributionDef.getProperty(other)
       }
       if (childDistribution == NoneDistribution) {
@@ -415,7 +417,7 @@ object DistributedSuite {
     extends DNode
     with UnaryLike {
     override def getDistribution(): Distribution = child match {
-      case g: Group => g.propSet.get(DistributionDef)
+      case g: Group => g.constraintSet.get(DistributionDef)
       case other => DistributionDef.getProperty(other)
     }
 
@@ -433,7 +435,7 @@ object DistributedSuite {
     with UnaryLike {
     override def getDistribution(): Distribution = {
       val childDistribution = child match {
-        case g: Group => g.propSet.get(DistributionDef)
+        case g: Group => g.constraintSet.get(DistributionDef)
         case other => DistributionDef.getProperty(other)
       }
       if (childDistribution == NoneDistribution) {
@@ -463,12 +465,12 @@ object DistributedSuite {
 
   case class DProject(override val child: TestNode) extends DNode with UnaryLike {
     override def getDistribution(): Distribution = child match {
-      case g: Group => g.propSet.get(DistributionDef)
+      case g: Group => g.constraintSet.get(DistributionDef)
       case other => DistributionDef.getProperty(other)
     }
     override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
     override def getOrdering(): Ordering = child match {
-      case g: Group => g.propSet.get(OrderingDef)
+      case g: Group => g.constraintSet.get(OrderingDef)
       case other => OrderingDef.getProperty(other)
     }
     override def getOrderingConstraints(req: Ordering): Seq[Ordering] = List(req)
@@ -482,7 +484,7 @@ object DistributedSuite {
     with UnaryLike {
     override def getDistribution(): Distribution = {
       val childDistribution = child match {
-        case g: Group => g.propSet.get(DistributionDef)
+        case g: Group => g.constraintSet.get(DistributionDef)
         case other => DistributionDef.getProperty(other)
       }
       if (childDistribution == NoneDistribution) {
@@ -501,13 +503,13 @@ object DistributedSuite {
 
   case class DSort(keys: Seq[String], override val child: TestNode) extends DNode with UnaryLike {
     override def getDistribution(): Distribution = child match {
-      case g: Group => g.propSet.get(DistributionDef)
+      case g: Group => g.constraintSet.get(DistributionDef)
       case other => DistributionDef.getProperty(other)
     }
     override def getDistributionConstraints(req: Distribution): Seq[Distribution] = List(req)
     override def getOrdering(): Ordering = {
       val childOrdering = child match {
-        case g: Group => g.propSet.get(OrderingDef)
+        case g: Group => g.constraintSet.get(OrderingDef)
         case other => OrderingDef.getProperty(other)
       }
       if (childOrdering.satisfies(SimpleOrdering(keys))) {

