This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 19da3bb1c [CORE][VL] RAS: Group expansion support (#5323)
19da3bb1c is described below

commit 19da3bb1cc8677913c097d7c4ccd1fc17189c438
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Apr 8 18:16:45 2024 +0800

    [CORE][VL] RAS: Group expansion support (#5323)
---
 .../src/main/scala/org/apache/gluten/ras/Ras.scala |   6 +-
 .../main/scala/org/apache/gluten/ras/RasNode.scala |   2 +-
 .../org/apache/gluten/ras/best/BestFinder.scala    |   2 -
 .../gluten/ras/best/GroupBasedBestFinder.scala     |   2 +-
 .../scala/org/apache/gluten/ras/dp/DpPlanner.scala |  23 ++--
 .../gluten/ras/exaustive/ExhaustivePlanner.scala   |  21 ++--
 .../apache/gluten/ras/memo/ForwardMemoTable.scala  |  25 +++--
 .../scala/org/apache/gluten/ras/memo/Memo.scala    |  28 +++--
 .../org/apache/gluten/ras/memo/MemoTable.scala     |  17 ++-
 .../org/apache/gluten/ras/path/OutputFilter.scala  |  33 ++++--
 .../org/apache/gluten/ras/path/OutputWizard.scala  | 100 +++++++++++-------
 .../org/apache/gluten/ras/path/PathFinder.scala    |  35 +++++--
 .../scala/org/apache/gluten/ras/path/Pattern.scala |   9 +-
 .../org/apache/gluten/ras/rule/RuleApplier.scala   |  12 +--
 .../scala/org/apache/gluten/ras/rule/Shape.scala   |  10 ++
 .../org/apache/gluten/ras/PropertySuite.scala      | 116 ++++++++++++++++-----
 .../scala/org/apache/gluten/ras/RasSuite.scala     |  67 ++++++++++++
 .../scala/org/apache/gluten/ras/RasSuiteBase.scala |   2 +-
 .../org/apache/gluten/ras/mock/MockMemoState.scala |   5 +
 .../org/apache/gluten/ras/mock/MockRasPath.scala   |   2 +-
 .../apache/gluten/ras/path/PathFinderSuite.scala   |  60 ++++++++++-
 .../org/apache/gluten/ras/path/WizardSuite.scala   |  14 +++
 .../gluten/ras/specific/DistributedSuite.scala     |   2 +
 23 files changed, 449 insertions(+), 144 deletions(-)

diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 9910fab6f..f3d46847e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -31,7 +31,7 @@ trait Optimization[T <: AnyRef] {
       constraintSet: PropertySet[T],
       altConstraintSets: Seq[PropertySet[T]]): RasPlanner[T]
 
-  def propSetsOf(plan: T): PropertySet[T]
+  def propSetOf(plan: T): PropertySet[T]
 
   def withNewConfig(confFunc: RasConfig => RasConfig): Optimization[T]
 }
@@ -49,7 +49,7 @@ object Optimization {
 
   implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
     def newPlanner(plan: T): RasPlanner[T] = {
-      opt.newPlanner(plan, opt.propSetsOf(plan), List.empty)
+      opt.newPlanner(plan, opt.propSetOf(plan), List.empty)
     }
     def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T] = {
       opt.newPlanner(plan, constraintSet, List.empty)
@@ -131,7 +131,7 @@ class Ras[T <: AnyRef] private (
     RasPlanner(this, altConstraintSets, constraintSet, plan)
   }
 
-  override def propSetsOf(plan: T): PropertySet[T] = 
propertySetFactory().get(plan)
+  override def propSetOf(plan: T): PropertySet[T] = 
propertySetFactory().get(plan)
 
   private[ras] def withNewChildren(node: T, newChildren: Seq[T]): T = {
     val oldChildren = planModel.childrenOf(node)
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 878020391..65ff8b735 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -54,7 +54,7 @@ trait CanonicalNode[T <: AnyRef] extends RasNode[T] {
 object CanonicalNode {
   def apply[T <: AnyRef](ras: Ras[T], canonical: T): CanonicalNode[T] = {
     assert(ras.isCanonical(canonical))
-    val propSet = ras.propSetsOf(canonical)
+    val propSet = ras.propSetOf(canonical)
     val children = ras.planModel.childrenOf(canonical)
     new CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
   }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
index 0912ab536..90a0adfb2 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
@@ -52,12 +52,10 @@ object BestFinder {
 
   private[best] def newBest[T <: AnyRef](
       ras: Ras[T],
-      allGroups: Seq[RasGroup[T]],
       group: RasGroup[T],
       groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = {
 
     val bestPath = groupToCosts(group.id()).best()
-    val bestRoot = bestPath.rasPath.node()
     val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, 
g.bestNode) }.toSeq
     val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
     groupToCosts.foreach {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
index 6db3600de..effebd41b 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
@@ -43,7 +43,7 @@ private class GroupBasedBestFinder[T <: AnyRef](
         s"Best path not found. Memo state (Graphviz): \n" +
           s"${memoState.formatGraphvizWithoutBest(groupId)}")
     }
-    BestFinder.newBest(ras, allGroups, group, groupToCosts)
+    BestFinder.newBest(ras, group, groupToCosts)
   }
 
   private def fillBests(group: RasGroup[T]): Map[Int, KnownCostGroup[T]] = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 1be728ae6..8acf66c59 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -99,8 +99,6 @@ object DpPlanner {
       rules: Seq[RuleApplier[T]],
       enforcerRuleSet: EnforcerRuleSet[T])
     extends DpClusterAlgo.Adjustment[T] {
-    private val allGroups = memoTable.allGroups()
-    private val clusterLookup = cKey => memoTable.getCluster(cKey)
 
     override def exploreChildX(
         panel: Panel[InClusterNode[T], RasClusterKey],
@@ -127,33 +125,30 @@ object DpPlanner {
       if (rules.isEmpty) {
         return
       }
-      val cluster = clusterLookup(cKey)
-      cluster.nodes().foreach {
-        node =>
-          val shapes = rules.map(_.shape())
-          findPaths(node, shapes)(path => rules.foreach(rule => 
applyRule(panel, cKey, rule, path)))
+      val dummyGroup = memoTable.getDummyGroup(cKey)
+      val shapes = rules.map(_.shape())
+      findPaths(GroupNode(ras, dummyGroup), shapes) {
+        path => rules.foreach(rule => applyRule(panel, cKey, rule, path))
       }
     }
 
     private def applyEnforcerRules(
         panel: Panel[InClusterNode[T], RasClusterKey],
         cKey: RasClusterKey): Unit = {
-      val cluster = clusterLookup(cKey)
+      val dummyGroup = memoTable.getDummyGroup(cKey)
       cKey.propSets(memoTable).foreach {
         constraintSet =>
           val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
           if (enforcerRules.nonEmpty) {
             val shapes = enforcerRules.map(_.shape())
-            cluster.nodes().foreach {
-              node =>
-                findPaths(node, shapes)(
-                  path => enforcerRules.foreach(rule => applyRule(panel, cKey, 
rule, path)))
+            findPaths(GroupNode(ras, dummyGroup), shapes) {
+              path => enforcerRules.foreach(rule => applyRule(panel, cKey, 
rule, path))
             }
           }
       }
     }
 
-    private def findPaths(canonical: CanonicalNode[T], shapes: Seq[Shape[T]])(
+    private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
         onFound: RasPath[T] => Unit): Unit = {
       val finder = shapes
         .foldLeft(
@@ -163,7 +158,7 @@ object DpPlanner {
             builder.output(shape.wizard())
         }
         .build()
-      finder.find(canonical).foreach(path => onFound(path))
+      finder.find(gn).foreach(path => onFound(path))
     }
 
     private def applyRule(
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index 47945fc14..3db649b64 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -94,7 +94,7 @@ object ExhaustivePlanner {
       applyRules()
     }
 
-    private def findPaths(canonical: CanonicalNode[T], shapes: Seq[Shape[T]])(
+    private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
         onFound: RasPath[T] => Unit): Unit = {
       val finder = shapes
         .foldLeft(
@@ -104,7 +104,7 @@ object ExhaustivePlanner {
             builder.output(shape.wizard())
         }
         .build()
-      finder.find(canonical).foreach(path => onFound(path))
+      finder.find(gn).foreach(path => onFound(path))
     }
 
     private def applyRule(rule: RuleApplier[T], icp: InClusterPath[T]): Unit = 
{
@@ -120,12 +120,10 @@ object ExhaustivePlanner {
         .clusterLookup()
         .foreach {
           case (cKey, cluster) =>
-            cluster
-              .nodes()
-              .foreach(
-                node =>
-                  findPaths(node, shapes)(
-                    path => rules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path)))))
+            val dummyGroup = memoState.getDummyGroup(cKey)
+            findPaths(GroupNode(ras, dummyGroup), shapes) {
+              path => rules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path)))
+            }
         }
     }
 
@@ -137,10 +135,9 @@ object ExhaustivePlanner {
           if (enforcerRules.nonEmpty) {
             val shapes = enforcerRules.map(_.shape())
             val cKey = group.clusterKey()
-            memoState.clusterLookup()(cKey).nodes().foreach {
-              node =>
-                findPaths(node, shapes)(
-                  path => enforcerRules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path))))
+            val dummyGroup = memoState.getDummyGroup(cKey)
+            findPaths(GroupNode(ras, dummyGroup), shapes) {
+              path => enforcerRules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path)))
             }
           }
       }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index e3ae03ebf..0895544d5 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -29,11 +29,12 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
   import ForwardMemoTable._
 
   private val groupBuffer: mutable.ArrayBuffer[RasGroup[T]] = 
mutable.ArrayBuffer()
+  private val dummyGroupBuffer: mutable.ArrayBuffer[RasGroup[T]] =
+    mutable.ArrayBuffer[RasGroup[T]]()
 
   private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] = 
mutable.ArrayBuffer()
   private val clusterBuffer: mutable.ArrayBuffer[MutableRasCluster[T]] = 
mutable.ArrayBuffer()
   private val clusterDisjointSet: IndexDisjointSet = IndexDisjointSet()
-  private val clusterDummyGroupBuffer = mutable.ArrayBuffer[RasGroup[T]]()
 
   private val groupLookup: mutable.ArrayBuffer[mutable.Map[PropertySet[T], 
RasGroup[T]]] =
     mutable.ArrayBuffer()
@@ -55,13 +56,16 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
     clusterDisjointSet.grow()
     groupLookup += mutable.Map()
     // Normal groups start with ID 0, so it's safe to use negative IDs for 
dummy groups.
-    clusterDummyGroupBuffer += RasGroup(ras, key, -clusterId, 
ras.propertySetFactory().any())
+    // Dummy group ID starts from -1.
+    dummyGroupBuffer += RasGroup(ras, key, -(clusterId + 1), 
ras.propertySetFactory().any())
     key
   }
 
-  override def dummyGroupOf(key: RasClusterKey): RasGroup[T] = {
+  override def getDummyGroup(key: RasClusterKey): RasGroup[T] = {
     val ancestor = ancestorClusterIdOf(key)
-    clusterDummyGroupBuffer(ancestor)
+    val out = dummyGroupBuffer(ancestor)
+    assert(out.id() == -(ancestor + 1))
+    out
   }
 
   override def groupOf(key: RasClusterKey, propSet: PropertySet[T]): 
RasGroup[T] = {
@@ -142,12 +146,21 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
     memoWriteCount += 1
   }
 
-  override def getGroup(id: Int): RasGroup[T] = groupBuffer(id)
+  override def getGroup(id: Int): RasGroup[T] = {
+    if (id < 0) {
+      val out = dummyGroupBuffer((-id - 1))
+      assert(out.id() == id)
+      return out
+    }
+    groupBuffer(id)
+  }
 
   override def allClusters(): Seq[RasClusterKey] = clusterKeyBuffer
 
   override def allGroups(): Seq[RasGroup[T]] = groupBuffer
 
+  override def allDummyGroups(): Seq[RasGroup[T]] = dummyGroupBuffer
+
   private def ancestorClusterIdOf(key: RasClusterKey): Int = {
     clusterDisjointSet.find(key.id())
   }
@@ -156,7 +169,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
     assert(clusterKeyBuffer.size == clusterBuffer.size)
     assert(clusterKeyBuffer.size == clusterDisjointSet.size)
     assert(clusterKeyBuffer.size == groupLookup.size)
-    assert(clusterKeyBuffer.size == clusterDummyGroupBuffer.size)
+    assert(clusterKeyBuffer.size == dummyGroupBuffer.size)
   }
 
   override def probe(): MemoTable.Probe[T] = new 
ForwardMemoTable.Probe[T](this)
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index 49281a82d..c1bb0a6bf 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -75,7 +75,7 @@ object Memo {
     }
 
     private def dummyGroupOf(clusterKey: RasClusterKey): RasGroup[T] = {
-      memoTable.dummyGroupOf(clusterKey)
+      memoTable.getDummyGroup(clusterKey)
     }
 
     private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = {
@@ -84,7 +84,7 @@ object Memo {
 
     private def prepareInsert(n: T): Prepare[T] = {
       if (ras.isGroupLeaf(n)) {
-        val group = memoTable.allGroups()(ras.planModel.getGroupId(n))
+        val group = memoTable.getGroup(ras.planModel.getGroupId(n))
         return Prepare.cluster(this, group.clusterKey())
       }
 
@@ -129,7 +129,7 @@ object Memo {
       // TODO: Traverse up the tree to do more merges.
       private def prepareInsert(node: T): Prepare[T] = {
         if (ras.isGroupLeaf(node)) {
-          val group = 
parent.memoTable.allGroups()(ras.planModel.getGroupId(node))
+          val group = parent.memoTable.getGroup(ras.planModel.getGroupId(node))
           val residentCluster = group.clusterKey()
 
           if (residentCluster == targetCluster) {
@@ -246,6 +246,7 @@ object Memo {
 
 trait MemoStore[T <: AnyRef] {
   def getCluster(key: RasClusterKey): RasCluster[T]
+  def getDummyGroup(key: RasClusterKey): RasGroup[T]
   def getGroup(id: Int): RasGroup[T]
 }
 
@@ -260,6 +261,7 @@ object MemoStore {
 trait MemoState[T <: AnyRef] extends MemoStore[T] {
   def ras(): Ras[T]
   def clusterLookup(): Map[RasClusterKey, RasCluster[T]]
+  def clusterDummyGroupLookup(): Map[RasClusterKey, RasGroup[T]]
   def allClusters(): Iterable[RasCluster[T]]
   def allGroups(): Seq[RasGroup[T]]
 }
@@ -268,19 +270,31 @@ object MemoState {
   def apply[T <: AnyRef](
       ras: Ras[T],
       clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
-      allGroups: Seq[RasGroup[T]]): MemoState[T] = {
-    MemoStateImpl(ras, clusterLookup, allGroups)
+      clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+      allGroups: Seq[RasGroup[T]],
+      allDummyGroups: Seq[RasGroup[T]]): MemoState[T] = {
+    MemoStateImpl(ras, clusterLookup, clusterDummyGroupLookup, allGroups, 
allDummyGroups)
   }
 
   private case class MemoStateImpl[T <: AnyRef](
       override val ras: Ras[T],
       override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
-      override val allGroups: Seq[RasGroup[T]])
+      override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+      override val allGroups: Seq[RasGroup[T]],
+      allDummyGroups: Seq[RasGroup[T]])
     extends MemoState[T] {
     private val allClustersCopy = clusterLookup.values
 
     override def getCluster(key: RasClusterKey): RasCluster[T] = 
clusterLookup(key)
-    override def getGroup(id: Int): RasGroup[T] = allGroups(id)
+    override def getDummyGroup(key: RasClusterKey): RasGroup[T] = 
clusterDummyGroupLookup(key)
+    override def getGroup(id: Int): RasGroup[T] = {
+      if (id < 0) {
+        val out = allDummyGroups((-id - 1))
+        assert(out.id() == id)
+        return out
+      }
+      allGroups(id)
+    }
     override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
   }
 
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 3baba8eae..8e5dac02e 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -30,6 +30,7 @@ sealed trait MemoTable[T <: AnyRef] extends MemoStore[T] {
 
   def allClusters(): Seq[RasClusterKey]
   def allGroups(): Seq[RasGroup[T]]
+  def allDummyGroups(): Seq[RasGroup[T]]
 
   def getClusterPropSets(key: RasClusterKey): Set[PropertySet[T]]
 
@@ -44,7 +45,6 @@ object MemoTable {
   trait Writable[T <: AnyRef] extends MemoTable[T] {
     def newCluster(metadata: Metadata): RasClusterKey
     def groupOf(key: RasClusterKey, propertySet: PropertySet[T]): RasGroup[T]
-    def dummyGroupOf(key: RasClusterKey): RasGroup[T]
 
     def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
     def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit
@@ -74,7 +74,20 @@ object MemoTable {
         .allClusters()
         .map(key => key -> ImmutableRasCluster(table.ras, 
table.getCluster(key)))
         .toMap
-      MemoState(table.ras, immutableClusters, table.allGroups())
+      val immutableDummyGroups = table
+        .allClusters()
+        .map(key => key -> table.getDummyGroup(key))
+        .toMap
+      table.allDummyGroups().zipWithIndex.foreach {
+        case (group, idx) =>
+          assert(group.id() == -(idx + 1))
+      }
+      MemoState(
+        table.ras,
+        immutableClusters,
+        immutableDummyGroups,
+        table.allGroups(),
+        table.allDummyGroups())
     }
 
     def doExhaustively(func: => Unit): Unit = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
index ed1c326f9..126ae7766 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
@@ -17,14 +17,15 @@
 package org.apache.gluten.ras.path
 
 import org.apache.gluten.ras.{CanonicalNode, GroupNode}
-import org.apache.gluten.ras.path.FilterWizard.FilterAction
-import org.apache.gluten.ras.path.OutputWizard.OutputAction
+import org.apache.gluten.ras.path.FilterWizard.{FilterAction, 
FilterAdvanceAction}
+import org.apache.gluten.ras.path.OutputWizard.{AdvanceAction, OutputAction}
 import org.apache.gluten.ras.util.CycleDetector
 
 trait FilterWizard[T <: AnyRef] {
   import FilterWizard._
   def omit(can: CanonicalNode[T]): FilterAction[T]
-  def omit(group: GroupNode[T], offset: Int, count: Int): FilterAction[T]
+  def omit(group: GroupNode[T]): FilterAction[T]
+  def advance(offset: Int, count: Int): FilterAdvanceAction[T]
 }
 
 object FilterWizard {
@@ -40,6 +41,11 @@ object FilterWizard {
 
     case class Continue[T <: AnyRef](newWizard: FilterWizard[T]) extends 
FilterAction[T]
   }
+
+  sealed trait FilterAdvanceAction[T <: AnyRef]
+  object FilterAdvanceAction {
+    case class Continue[T <: AnyRef](newWizard: FilterWizard[T]) extends 
FilterAdvanceAction[T]
+  }
 }
 
 object FilterWizards {
@@ -55,12 +61,15 @@ object FilterWizards {
       FilterAction.Continue(this)
     }
 
-    override def omit(group: GroupNode[T], offset: Int, count: Int): 
FilterAction[T] = {
+    override def omit(group: GroupNode[T]): FilterAction[T] = {
       if (detector.contains(group)) {
         return FilterAction.omit
       }
       FilterAction.Continue(new OmitCycles(detector.append(group)))
     }
+
+    override def advance(offset: Int, count: Int): FilterAdvanceAction[T] =
+      FilterAdvanceAction.Continue(this)
   }
 
   private object OmitCycles {
@@ -98,11 +107,11 @@ object OutputFilter {
       }
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] = {
-      filterWizard.omit(group: GroupNode[T], offset: Int, count: Int) match {
+    override def visit(group: GroupNode[T]): OutputAction[T] = {
+      filterWizard.omit(group: GroupNode[T]) match {
         case FilterAction.Omit() => OutputAction.stop
         case FilterAction.Continue(newFilterWizard) =>
-          outputWizard.advance(group, offset, count) match {
+          outputWizard.visit(group) match {
             case stop @ OutputAction.Stop(_) => stop
             case OutputAction.Continue(drain, newOutputWizard) =>
               OutputAction.Continue(drain, new 
OutputFilterImpl(newOutputWizard, newFilterWizard))
@@ -110,6 +119,16 @@ object OutputFilter {
       }
     }
 
+    override def advance(offset: Int, count: Int): 
OutputWizard.AdvanceAction[T] = {
+      val newOutputWizard = outputWizard.advance(offset, count) match {
+        case AdvanceAction.Continue(newWizard) => newWizard
+      }
+      val newFilterWizard = filterWizard.advance(offset, count) match {
+        case FilterAdvanceAction.Continue(newWizard) => newWizard
+      }
+      AdvanceAction.Continue(new OutputFilterImpl(newOutputWizard, 
newFilterWizard))
+    }
+
     override def withPathKey(newKey: PathKey): OutputWizard[T] =
       new OutputFilterImpl[T](outputWizard.withPathKey(newKey), filterWizard)
   }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
index 0e40c36d8..73e4a5c19 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.ras.path
 
 import org.apache.gluten.ras.{CanonicalNode, GroupNode, Ras, RasGroup}
-import org.apache.gluten.ras.path.OutputWizard.{OutputAction, PathDrain}
+import org.apache.gluten.ras.path.OutputWizard.{AdvanceAction, OutputAction, 
PathDrain}
 
 import scala.collection.{mutable, Seq}
 
@@ -25,11 +25,11 @@ trait OutputWizard[T <: AnyRef] {
   import OutputWizard._
   // Visit a new node.
   def visit(can: CanonicalNode[T]): OutputAction[T]
-  // The returned object is a wizard for one of the node's children at the
+  // Visit a new group.
+  def visit(group: GroupNode[T]): OutputAction[T]
+  // The returned action is typically a wizard for one of the node's children 
at the
   // known offset among all children.
-  def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T]
-  // The returned wizard would be same with this wizard
-  // except it drains paths with the input path key.
+  def advance(offset: Int, count: Int): AdvanceAction[T]
   def withPathKey(newKey: PathKey): OutputWizard[T]
 }
 
@@ -50,6 +50,11 @@ object OutputWizard {
       extends OutputAction[T]
   }
 
+  sealed trait AdvanceAction[T <: AnyRef]
+  object AdvanceAction {
+    case class Continue[T <: AnyRef](newWizard: OutputWizard[T]) extends 
AdvanceAction[T]
+  }
+
   // Path drain provides possibility to lazily materialize the yielded paths 
using path key.
   // Otherwise if each wizard emits its own paths during visiting, the de-dup 
operation
   // will be required and could cause serious performance issues.
@@ -97,12 +102,8 @@ object OutputWizard {
       new NodePrepareImpl[T](ras, wizard, allGroups, can)
     }
 
-    def prepareForGroup(
-        ras: Ras[T],
-        group: GroupNode[T],
-        offset: Int,
-        count: Int): GroupPrepare[T] = {
-      new GroupPrepareImpl[T](ras, wizard, group, offset, count)
+    def prepareForGroup(ras: Ras[T], group: GroupNode[T]): GroupPrepare[T] = {
+      new GroupPrepareImpl[T](ras, wizard, group)
     }
   }
 
@@ -112,7 +113,7 @@ object OutputWizard {
     }
 
     sealed trait GroupPrepare[T <: AnyRef] {
-      def advance(): Terminate[T]
+      def visit(): Terminate[T]
     }
 
     sealed trait Terminate[T <: AnyRef] {
@@ -154,12 +155,10 @@ object OutputWizard {
     private class GroupPrepareImpl[T <: AnyRef](
         ras: Ras[T],
         wizard: OutputWizard[T],
-        group: GroupNode[T],
-        offset: Int,
-        count: Int)
+        group: GroupNode[T])
       extends GroupPrepare[T] {
-      override def advance(): Terminate[T] = {
-        val action = wizard.advance(group, offset, count)
+      override def visit(): Terminate[T] = {
+        val action = wizard.visit(group)
         val drained = if (action.drain().isEmpty()) {
           List.empty
         } else {
@@ -201,8 +200,12 @@ object OutputWizards {
       OutputAction.Stop(PathDrain.none)
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] =
+    override def visit(group: GroupNode[T]): OutputAction[T] = {
       OutputAction.Stop(PathDrain.none)
+    }
+
+    override def advance(offset: Int, count: Int): 
OutputWizard.AdvanceAction[T] =
+      AdvanceAction.Continue(this)
 
     override def withPathKey(newKey: PathKey): OutputWizard[T] = this
   }
@@ -219,9 +222,11 @@ object OutputWizards {
       OutputAction.Continue(PathDrain.none, this)
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] =
+    override def visit(group: GroupNode[T]): OutputAction[T] =
       OutputAction.Continue(PathDrain.none, this)
 
+    override def advance(offset: Int, count: Int): AdvanceAction[T] = 
AdvanceAction.Continue(this)
+
     override def withPathKey(newKey: PathKey): OutputWizard[T] =
       new Emit[T](PathDrain.Specific(List(newKey)))
   }
@@ -267,12 +272,20 @@ object OutputWizards {
       act(actions)
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] = {
+    override def visit(group: GroupNode[T]): OutputAction[T] = {
       val actions = wizards
-        .map(_.advance(group, offset, count))
+        .map(_.visit(group))
       act(actions)
     }
 
+    override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+      val newWizards = wizards.map(_.advance(offset, count)).map {
+        case AdvanceAction.Continue(newWizard) =>
+          newWizard
+      }
+      AdvanceAction.Continue(new Union(newWizards))
+    }
+
     override def withPathKey(newKey: PathKey): OutputWizard[T] =
       new Union[T](wizards.map(w => w.withPathKey(newKey)))
   }
@@ -331,13 +344,17 @@ object OutputWizards {
       OutputAction.Continue(PathDrain.none, this)
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] = {
-      var skipCursor = ele + 1
-      (0 until offset).foreach(_ => skipCursor = mask.skip(skipCursor))
-      if (mask.isAny(skipCursor)) {
+    override def visit(group: GroupNode[T]): OutputAction[T] = {
+      if (mask.isAny(ele)) {
         return OutputAction.Stop(drain)
       }
-      OutputAction.Continue(PathDrain.none, new WithMask[T](drain, mask, 
skipCursor))
+      OutputAction.Continue(PathDrain.none, this)
+    }
+
+    override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+      var skipCursor = ele + 1
+      (0 until offset).foreach(_ => skipCursor = mask.skip(skipCursor))
+      AdvanceAction.Continue(new WithMask[T](drain, mask, skipCursor))
     }
 
     override def withPathKey(newKey: PathKey): OutputWizard[T] =
@@ -372,13 +389,16 @@ object OutputWizards {
       OutputAction.Continue(PathDrain.none, this)
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] = {
-      // Omit should be done in #advance.
-      val child = pNode.children(count)(offset)
-      if (child.skip()) {
+    override def visit(group: GroupNode[T]): OutputAction[T] = {
+      if (pNode.skip()) {
         return OutputAction.Stop(drain)
       }
-      OutputAction.Continue(PathDrain.none, new WithPattern(drain, pattern, 
child))
+      OutputAction.Continue(PathDrain.none, this)
+    }
+
+    override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+      val child = pNode.children(count)(offset)
+      AdvanceAction.Continue(new WithPattern(drain, pattern, child))
     }
 
     override def withPathKey(newKey: PathKey): OutputWizard[T] =
@@ -398,8 +418,8 @@ object OutputWizards {
 
     override def visit(can: CanonicalNode[T]): OutputAction[T] = {
       assert(
-        currentDepth <= depth,
-        "Current depth already larger than the maximum depth to prune. " +
+        currentDepth < depth,
+        "Current depth already larger than (or equals) the maximum depth to 
prune. " +
           "It probably because a zero depth was specified for path finding."
       )
       if (can.isLeaf()) {
@@ -408,13 +428,15 @@ object OutputWizards {
       OutputAction.Continue(PathDrain.none, this)
     }
 
-    override def advance(group: GroupNode[T], offset: Int, count: Int): 
OutputAction[T] = {
-      assert(currentDepth <= depth)
-      val nextDepth = currentDepth + 1
-      if (nextDepth > depth) {
+    override def visit(group: GroupNode[T]): OutputAction[T] = {
+      if (currentDepth >= depth) {
         return OutputAction.Stop(drain)
       }
-      OutputAction.Continue(PathDrain.none, new WithMaxDepth(drain, depth, 
nextDepth))
+      OutputAction.Continue(PathDrain.none, this)
+    }
+
+    override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+      AdvanceAction.Continue(new WithMaxDepth(drain, depth, currentDepth + 1))
     }
 
     override def withPathKey(newKey: PathKey): OutputWizard[T] =
@@ -423,7 +445,7 @@ object OutputWizards {
 
   private object WithMaxDepth {
     def apply[T <: AnyRef](depth: Int): WithMaxDepth[T] = {
-      new WithMaxDepth(PathDrain.Specific(List(PathKey.random())), depth, 1)
+      new WithMaxDepth(PathDrain.Specific(List(PathKey.random())), depth, 0)
     }
   }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
index c5949a9d7..78aed142a 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
@@ -18,11 +18,13 @@ package org.apache.gluten.ras.path
 
 import org.apache.gluten.ras.{CanonicalNode, GroupNode, Ras}
 import org.apache.gluten.ras.memo.MemoStore
+import org.apache.gluten.ras.path.OutputWizard.AdvanceAction
 
 import scala.collection.mutable
 
 trait PathFinder[T <: AnyRef] {
   def find(base: CanonicalNode[T]): Iterable[RasPath[T]]
+  def find(base: GroupNode[T]): Iterable[RasPath[T]]
   def find(base: RasPath[T]): Iterable[RasPath[T]]
 }
 
@@ -88,6 +90,14 @@ object PathFinder {
       all
     }
 
+    override def find(base: GroupNode[T]): Iterable[RasPath[T]] = {
+      val all =
+        wizard.prepareForGroup(ras, base).visit().onContinue {
+          newWizard => enumerateFromGroup(base, newWizard)
+        }
+      all
+    }
+
     override def find(base: RasPath[T]): Iterable[RasPath[T]] = {
       val can = base.node().self().asCanonical()
       val all = wizard.prepareForNode(ras, memoStore.asGroupSupplier(), can).visit().onContinue {
@@ -123,9 +133,13 @@ object PathFinder {
         childrenGroups.zipWithIndex.map {
           case (childGroup, index) =>
             wizard
-              .prepareForGroup(ras, childGroup, index, childrenGroups.size)
-              .advance()
-              .onContinue(newWizard => enumerateFromGroup(childGroup, newWizard))
+              .advance(index, childrenGroups.size) match {
+              case AdvanceAction.Continue(newWizard) =>
+                newWizard
+                  .prepareForGroup(ras, childGroup)
+                  .visit()
+                  .onContinue(newWizard => enumerateFromGroup(childGroup, newWizard))
+            }
         }
       RasPath.cartesianProduct(ras, canonical, expandedChildren)
     }
@@ -168,11 +182,16 @@ object PathFinder {
         children.zip(childrenGroups).zipWithIndex.map {
           case ((child, childGroup), index) =>
             wizard
-              .prepareForGroup(ras, childGroup, index, childrenGroups.size)
-              .advance()
-              .onContinue {
-                newWizard => diveFromGroup(depth - 1, GroupedPathNode(childGroup, child), newWizard)
-              }
+              .advance(index, childrenGroups.size) match {
+              case AdvanceAction.Continue(newWizard) =>
+                newWizard
+                  .prepareForGroup(ras, childGroup)
+                  .visit()
+                  .onContinue {
+                    newWizard =>
+                      diveFromGroup(depth - 1, GroupedPathNode(childGroup, child), newWizard)
+                  }
+            }
         }
       )
     }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index b694279ec..0471fb42a 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -90,21 +90,20 @@ object Pattern {
 
   private case class PatternImpl[T <: AnyRef](root: Node[T]) extends Pattern[T] {
     override def matches(path: RasPath[T], depth: Int): Boolean = {
-      assert(depth >= 1)
+      assert(depth >= 0)
       assert(depth <= path.height())
       def dfs(remainingDepth: Int, patternN: Node[T], n: PathNode[T]): Boolean = {
         assert(remainingDepth >= 0)
-        assert(n.self().isCanonical)
         if (remainingDepth == 0) {
           return true
         }
+        if (patternN.skip()) {
+          return true
+        }
         val can = n.self().asCanonical()
         if (patternN.abort(can)) {
           return false
         }
-        if (patternN.skip()) {
-          return true
-        }
         if (!patternN.matches(can)) {
           return false
         }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index 0a7bf0c76..01e826f06 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.ras._
 import org.apache.gluten.ras.Ras.UnsafeKey
 import org.apache.gluten.ras.memo.Closure
 import org.apache.gluten.ras.path.InClusterPath
+import org.apache.gluten.ras.property.PropertySet
 
 import scala.collection.mutable
 
@@ -82,8 +83,8 @@ object RuleApplier {
     override def apply(icp: InClusterPath[T]): Unit = {
       val cKey = icp.cluster()
       val path = icp.path()
-      val can = path.node().self().asCanonical()
-      if (can.propSet().get(constraintDef).satisfies(constraint)) {
+      val propSet = path.node().self().propSet()
+      if (propSet.get(constraintDef).satisfies(constraint)) {
         return
       }
       val plan = path.plan()
@@ -92,13 +93,12 @@ object RuleApplier {
       if (appliedPlans.contains(pKey)) {
         return
       }
-      apply0(cKey, plan)
+      val constraintSet = propSet.withProp(constraint)
+      apply0(cKey, constraintSet, plan)
       appliedPlans += pKey
     }
 
-    private def apply0(cKey: RasClusterKey, plan: T): Unit = {
-      val propSet = ras.propertySetFactory().get(plan)
-      val constraintSet = propSet.withProp(constraint)
+    private def apply0(cKey: RasClusterKey, constraintSet: PropertySet[T], plan: T): Unit = {
       val equivalents = rule.shift(plan)
       equivalents.foreach {
         equiv =>
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
index 400f0eeba..2a8255861 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
@@ -33,10 +33,20 @@ object Shapes {
     new FixedHeight[T](height)
   }
 
+  def pattern[T <: AnyRef](pattern: org.apache.gluten.ras.path.Pattern[T]): Shape[T] = {
+    new Pattern(pattern)
+  }
+
   def none[T <: AnyRef](): Shape[T] = {
     new None()
   }
 
+  private class Pattern[T <: AnyRef](pattern: org.apache.gluten.ras.path.Pattern[T])
+    extends Shape[T] {
+    override def wizard(): OutputWizard[T] = OutputWizards.withPattern(pattern)
+    override def identify(path: RasPath[T]): Boolean = pattern.matches(path, path.height())
+  }
+
   private class FixedHeight[T <: AnyRef](height: Int) extends Shape[T] {
     override def wizard(): OutputWizard[T] = OutputWizards.withMaxDepth(height)
     override def identify(path: RasPath[T]): Boolean = path.height() == height
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index e48604116..aed032226 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -27,16 +27,29 @@ import org.scalatest.funsuite.AnyFunSuite
 
 class ExhaustivePlannerPropertySuite extends PropertySuite {
   override protected def conf: RasConfig = RasConfig(plannerType = PlannerType.Exhaustive)
+  override protected def zeroDepth: Boolean = false
 }
 
 class DpPlannerPropertySuite extends PropertySuite {
   override protected def conf: RasConfig = RasConfig(plannerType = PlannerType.Dp)
+  override protected def zeroDepth: Boolean = false
+}
+
+class ExhaustivePlannerPropertyZeroDepthSuite extends PropertySuite {
+  override protected def conf: RasConfig = RasConfig(plannerType = PlannerType.Exhaustive)
+  override protected def zeroDepth: Boolean = true
+}
+
+class DpPlannerPropertyZeroDepthSuite extends PropertySuite {
+  override protected def conf: RasConfig = RasConfig(plannerType = PlannerType.Dp)
+  override protected def zeroDepth: Boolean = true
 }
 
 abstract class PropertySuite extends AnyFunSuite {
   import PropertySuite._
 
   protected def conf: RasConfig
+  protected def zeroDepth: Boolean
 
   test("Group memo - cache") {
     val ras =
@@ -44,7 +57,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModelWithOutEnforcerRules,
+        propertyModelWithoutEnforcerRules(),
         ExplainImpl,
         RasRule.Factory.none())
         .withNewConfig(_ => conf)
@@ -86,7 +99,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModelWithOutEnforcerRules,
+        propertyModelWithoutEnforcerRules(),
         ExplainImpl,
         RasRule.Factory.none())
         .withNewConfig(_ => conf)
@@ -103,7 +116,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.none())
         .withNewConfig(_ => conf)
@@ -121,7 +134,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.reuse(List(ReplaceByTypeARule, ReplaceByTypeBRule)))
         .withNewConfig(_ => conf)
@@ -177,7 +190,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModelWithOutEnforcerRules,
+        propertyModelWithoutEnforcerRules(),
         ExplainImpl,
         RasRule.Factory.reuse(List(ReplaceLeafAByLeafBRule, HitCacheOp, FinalOp))
       )
@@ -223,7 +236,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModelWithOutEnforcerRules,
+        propertyModelWithoutEnforcerRules(),
         ExplainImpl,
         RasRule.Factory.reuse(List(ReplaceLeafAByLeafBRule, ReplaceUnaryBByUnaryARule))
       )
@@ -252,7 +265,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.reuse(List(ConvertEnforcerAndTypeAToTypeB)))
         .withNewConfig(_ => conf)
@@ -290,7 +303,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.reuse(List(ReplaceByTypeARule, ReplaceNonUnaryByTypeBRule))
       )
@@ -342,9 +355,10 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
-        RasRule.Factory.reuse(List(ReduceTypeBCost, ConvertUnaryTypeBToTypeC)))
+        RasRule.Factory.reuse(List(ReduceTypeBCost, ConvertUnaryTypeBToTypeC))
+      )
         .withNewConfig(_ => conf)
 
     val plan =
@@ -394,7 +408,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.reuse(List(LeftOp, RightOp))
       )
@@ -441,7 +455,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.reuse(List(ConvertTypeBEnforcerAndLeafToTypeC, ConvertTypeATypeCToTypeC))
       )
@@ -494,7 +508,7 @@ abstract class PropertySuite extends AnyFunSuite {
         PlanModelImpl,
         CostModelImpl,
         MetadataModelImpl,
-        NodeTypePropertyModel,
+        propertyModel(zeroDepth),
         ExplainImpl,
         RasRule.Factory.reuse(List(ReplaceNonUnaryByTypeBRule, ReduceTypeBCost))
       )
@@ -651,6 +665,23 @@ object PropertySuite {
     override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
   }
 
+  case class ZeroDepthNodeTypeEnforcerRule(reqType: NodeType) extends RasRule[TestNode] {
+    override def shift(node: TestNode): Iterable[TestNode] = {
+      node match {
+        case group: Group =>
+          val groupType = group.propSet.get(NodeTypeDef)
+          if (groupType.satisfies(reqType)) {
+            List(group)
+          } else {
+            List(TypeEnforcer(reqType, 1, group))
+          }
+        case _ => throw new IllegalStateException()
+      }
+    }
+
+    override def shape(): Shape[TestNode] = Shapes.fixedHeight(0)
+  }
+
   object ReplaceByTypeARule extends RasRule[TestNode] {
     override def shift(node: TestNode): Iterable[TestNode] = {
       node match {
@@ -739,24 +770,53 @@ object PropertySuite {
     }
   }
 
-  object NodeTypePropertyModel extends PropertyModel[TestNode] {
-    override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
-      NodeTypeDef)
-
-    override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
-        : EnforcerRuleFactory[TestNode] = {
-      (constraint: Property[TestNode]) =>
-        {
-          List(NodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
-        }
+  private def propertyModel(zeroDepth: Boolean): PropertyModel[TestNode] = {
+    if (zeroDepth) {
+      return PropertyModels.NodeTypePropertyModelZeroDepth
     }
+    PropertyModels.NodeTypePropertyModel
   }
 
-  object NodeTypePropertyModelWithOutEnforcerRules extends PropertyModel[TestNode] {
-    override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
-      NodeTypeDef)
+  private def propertyModelWithoutEnforcerRules(): PropertyModel[TestNode] = {
+    PropertyModels.NodeTypePropertyModelWithoutEnforcerRules
+  }
 
-    override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
-        : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) => List.empty
+  private object PropertyModels {
+    object NodeTypePropertyModel extends PropertyModel[TestNode] {
+      override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
+        NodeTypeDef)
+
+      override def newEnforcerRuleFactory(
+          propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+          : EnforcerRuleFactory[TestNode] = {
+        (constraint: Property[TestNode]) =>
+          {
+            List(NodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+          }
+      }
+    }
+
+    object NodeTypePropertyModelZeroDepth extends PropertyModel[TestNode] {
+      override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
+        NodeTypeDef)
+
+      override def newEnforcerRuleFactory(
+          propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+          : EnforcerRuleFactory[TestNode] = {
+        (constraint: Property[TestNode]) =>
+          {
+            List(ZeroDepthNodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+          }
+      }
+    }
+
+    object NodeTypePropertyModelWithoutEnforcerRules extends PropertyModel[TestNode] {
+      override def propertyDefs: Seq[PropertyDef[TestNode, _ <: Property[TestNode]]] = Seq(
+        NodeTypeDef)
+
+      override def newEnforcerRuleFactory(
+          propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+          : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) => List.empty
+    }
   }
 }
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index f8a3d0799..b29e0c267 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.ras
 import org.apache.gluten.ras.RasConfig.PlannerType
 import org.apache.gluten.ras.RasSuiteBase._
 import org.apache.gluten.ras.memo.Memo
+import org.apache.gluten.ras.path.Pattern
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.scalatest.funsuite.AnyFunSuite
@@ -198,6 +199,72 @@ abstract class RasSuite extends AnyFunSuite {
     assert(optimized == Leaf(70))
   }
 
+  test(s"Group expansion - fixed height") {
+    object AddUnary extends RasRule[TestNode] {
+      override def shift(node: TestNode): Iterable[TestNode] = {
+        assert(node.isInstanceOf[Group])
+        List(Unary(50, node))
+      }
+
+      override def shape(): Shape[TestNode] = Shapes.fixedHeight(0)
+    }
+
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.reuse(List(AddUnary)))
+        .withNewConfig(_ => conf)
+    val plan = Unary(60, Unary(90, Leaf(70)))
+    val planner = ras.newPlanner(plan)
+    val optimized = planner.plan()
+
+    assert(optimized == Unary(60, Unary(90, Leaf(70))))
+
+    val state = planner.newState().memoState()
+    val allPaths = state.collectAllPaths(Int.MaxValue)
+
+    assert(state.allClusters().size == 3)
+    assert(state.allGroups().size == 3)
+    assert(allPaths.size == 15)
+  }
+
+  test(s"Group expansion - pattern") {
+    object AddUnary extends RasRule[TestNode] {
+      override def shift(node: TestNode): Iterable[TestNode] = {
+        assert(node.isInstanceOf[Group])
+        List(Unary(50, node))
+      }
+
+      override def shape(): Shape[TestNode] = Shapes.pattern(Pattern.ignore.build())
+    }
+
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.reuse(List(AddUnary)))
+        .withNewConfig(_ => conf)
+    val plan = Unary(60, Unary(90, Leaf(70)))
+    val planner = ras.newPlanner(plan)
+    val optimized = planner.plan()
+
+    assert(optimized == Unary(60, Unary(90, Leaf(70))))
+
+    val state = planner.newState().memoState()
+    val allPaths = state.collectAllPaths(Int.MaxValue)
+
+    assert(state.allClusters().size == 3)
+    assert(state.allGroups().size == 3)
+    assert(allPaths.size == 15)
+  }
+
   test(s"Unary node insertion") {
     object InsertUnary2 extends RasRule[TestNode] {
       override def shift(node: TestNode): Iterable[TestNode] = node match {
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
index 5432a6c78..b5455d6af 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
@@ -163,7 +163,7 @@ object RasSuiteBase {
 
   implicit class MemoLikeImplicits[T <: AnyRef](val memo: MemoLike[T]) {
     def memorize(ras: Ras[T], node: T): RasGroup[T] = {
-      memo.memorize(node, ras.propSetsOf(node))
+      memo.memorize(node, ras.propSetOf(node))
     }
   }
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index e4c8cb30b..7bb713afe 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -48,6 +48,11 @@ case class MockMemoState[T <: AnyRef] private (
   override def getCluster(key: RasClusterKey): RasCluster[T] = clusterLookup(key)
 
   override def getGroup(id: Int): RasGroup[T] = allGroups(id)
+
+  override def clusterDummyGroupLookup(): Map[RasClusterKey, RasGroup[T]] = Map.empty
+
+  override def getDummyGroup(key: RasClusterKey): RasGroup[T] =
+    throw new UnsupportedOperationException()
 }
 
 object MockMemoState {
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
index eac3d5f70..cd8050e5f 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
@@ -27,7 +27,7 @@ object MockRasPath {
 
   def mock[T <: AnyRef](ras: Ras[T], node: T, keys: PathKeySet): RasPath[T] = {
     val memo = Memo(ras)
-    val g = memo.memorize(node, ras.propSetsOf(node))
+    val g = memo.memorize(node, ras.propSetOf(node))
     val state = memo.newState()
     val groupSupplier = state.asGroupSupplier()
     assert(g.nodes(state).size == 1)
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
index 0c8880c29..b5ea3fc3c 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.ras.path
 
-import org.apache.gluten.ras.{CanonicalNode, Ras}
+import org.apache.gluten.ras.{CanonicalNode, Ras, RasGroup}
 import org.apache.gluten.ras.RasSuiteBase._
 import org.apache.gluten.ras.mock.MockMemoState
 import org.apache.gluten.ras.rule.RasRule
@@ -81,6 +81,64 @@ class PathFinderSuite extends AnyFunSuite {
         Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
   }
 
+  test("Find - from group") {
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.none())
+
+    val mock = MockMemoState.Builder(ras)
+    val cluster = mock.newCluster()
+    val groupA = cluster.newGroup()
+    val groupB = cluster.newGroup()
+    val groupC = cluster.newGroup()
+    val groupD = cluster.newGroup()
+    val groupE = cluster.newGroup()
+    val n1 = "n1"
+    val n2 = "n2"
+    val n3 = "n3"
+    val n4 = "n4"
+    val n5 = "n5"
+    val n6 = "n6"
+    val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(ras)
+    val node2 = Unary(n2, groupD.self).asCanonical(ras)
+    val node3 = Unary(n3, groupE.self).asCanonical(ras)
+    val node4 = Leaf(n4, 1).asCanonical(ras)
+    val node5 = Leaf(n5, 1).asCanonical(ras)
+    val node6 = Leaf(n6, 1).asCanonical(ras)
+
+    groupA.add(node1)
+    groupB.add(node2)
+    groupC.add(node3)
+    groupD.add(node4)
+    groupE.add(List(node5, node6))
+
+    val state = mock.build()
+
+    def find(group: RasGroup[TestNode], depth: Int): Iterable[RasPath[TestNode]] = {
+      val finder = PathFinder.builder(ras, state).depth(depth).build()
+      finder.find(group.asGroup(ras))
+    }
+
+    val height0 = find(groupA, 0).map(_.plan()).toSeq
+    val height1 = find(groupA, 1).map(_.plan()).toSeq
+    val height2 = find(groupA, 2).map(_.plan()).toSeq
+    val heightInf = find(groupA, RasPath.INF_DEPTH).map(_.plan()).toSeq
+
+    assert(height0 == List(Group(0)))
+    assert(height1 == List(Binary(n1, Group(1), Group(2))))
+    assert(height2 == List(Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
+    assert(
+      heightInf == List(
+        Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+        Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+
+  }
+
   test("Find - multiple depths") {
     val ras =
       Ras[TestNode](
diff --git 
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala 
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
index c7505ef74..523a22689 100644
--- 
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
+++ 
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
@@ -178,8 +178,22 @@ class WizardSuite extends AnyFunSuite {
       finder.find(node1).map(_.plan()).toSeq
     }
 
+    def findWithPatternsFromGroup(patterns: Seq[Pattern[TestNode]]): Seq[TestNode] = {
+      val builder = PathFinder.builder(ras, state)
+      val finder = patterns
+        .foldLeft(builder) {
+          case (builder, pattern) =>
+            builder.output(OutputWizards.withPattern(pattern))
+        }
+        .build()
+      finder.find(groupA.asGroup(ras)).map(_.plan()).toSeq
+    }
+
+    assert(findWithPatternsFromGroup(List(Pattern.ignore[TestNode].build())) == List(Group(0)))
+
     assert(
       findWithPatterns(List(Pattern.any[TestNode].build())) == List(Binary(n1, Group(1), Group(2))))
+
     assert(
       findWithPatterns(
         List(
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index de71cba5b..2aefc54e9 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -242,6 +242,7 @@ object DistributedSuite {
 
   case object NoneDistribution extends Distribution {
     override def satisfies(other: Property[TestNode]): Boolean = other match {
+      case AnyDistribution => true
       case _: Distribution => false
       case _ => throw new UnsupportedOperationException()
     }
@@ -296,6 +297,7 @@ object DistributedSuite {
 
   case object NoneOrdering extends Ordering {
     override def satisfies(other: Property[TestNode]): Boolean = other match {
+      case AnyOrdering => true
       case _: Ordering => false
       case _ => throw new UnsupportedOperationException()
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to