This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new bc99c4910 [CORE][VL] Avoid re-exploring explored nodes in DpPlanner (#5363)
bc99c4910 is described below

commit bc99c4910e26648195fec789c090b06c2b4379e2
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Apr 11 15:53:26 2024 +0800

    [CORE][VL] Avoid re-exploring explored nodes in DpPlanner (#5363)
---
 .../src/main/scala/org/apache/gluten/ras/Ras.scala | 14 ++---
 .../scala/org/apache/gluten/ras/RasCluster.scala   |  8 +--
 .../main/scala/org/apache/gluten/ras/RasNode.scala | 21 ++++----
 .../scala/org/apache/gluten/ras/RasPlanner.scala   | 10 ++--
 .../org/apache/gluten/ras/best/BestFinder.scala    |  6 +--
 .../org/apache/gluten/ras/dp/DpClusterAlgo.scala   |  2 +-
 .../org/apache/gluten/ras/dp/DpGroupAlgo.scala     |  2 +-
 .../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 62 +++++++++++++++++-----
 .../scala/org/apache/gluten/ras/memo/Memo.scala    |  6 +--
 .../org/apache/gluten/ras/path/OutputFilter.scala  | 15 ++++++
 .../org/apache/gluten/ras/rule/EnforcerRule.scala  | 39 +++++++++-----
 .../org/apache/gluten/ras/rule/RuleApplier.scala   | 10 ++--
 .../apache/gluten/ras/vis/GraphvizVisualizer.scala |  4 +-
 .../scala/org/apache/gluten/ras/RasSuite.scala     | 53 ++++++++++++++++++
 14 files changed, 184 insertions(+), 68 deletions(-)

diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index f3d46847e..804d04d81 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -171,7 +171,7 @@ class Ras[T <: AnyRef] private (
 
   private[ras] def isInfCost(cost: Cost) = 
costModel.costComparator().equiv(cost, infCost)
 
-  private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node)
+  private[ras] def toHashKey(node: T): UnsafeHashKey[T] = UnsafeHashKey(this, 
node)
 }
 
 object Ras {
@@ -251,15 +251,17 @@ object Ras {
     }
   }
 
-  trait UnsafeKey[T]
+  trait UnsafeHashKey[T]
 
-  private object UnsafeKey {
-    def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new 
UnsafeKeyImpl(ras, self)
-    private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends 
UnsafeKey[T] {
+  private object UnsafeHashKey {
+    def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeHashKey[T] =
+      new UnsafeHashKeyImpl(ras, self)
+    private class UnsafeHashKeyImpl[T <: AnyRef](ras: Ras[T], val self: T)
+      extends UnsafeHashKey[T] {
       override def hashCode(): Int = ras.planModel.hashCode(self)
       override def equals(other: Any): Boolean = {
         other match {
-          case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self)
+          case that: UnsafeHashKeyImpl[T] => ras.planModel.equals(self, 
that.self)
           case _ => false
         }
       }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index 1b30e1242..eb2b41a91 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.ras
 
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
 import org.apache.gluten.ras.memo.MemoTable
 import org.apache.gluten.ras.property.PropertySet
 
@@ -55,16 +55,16 @@ object RasCluster {
         override val ras: Ras[T],
         metadata: Metadata)
       extends MutableRasCluster[T] {
-      private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set()
+      private val deDup: mutable.Set[UnsafeHashKey[T]] = mutable.Set()
       private val buffer: mutable.ListBuffer[CanonicalNode[T]] =
         mutable.ListBuffer()
 
       override def contains(t: CanonicalNode[T]): Boolean = {
-        deDup.contains(t.toUnsafeKey())
+        deDup.contains(t.toHashKey())
       }
 
       override def add(t: CanonicalNode[T]): Unit = {
-        val key = t.toUnsafeKey()
+        val key = t.toHashKey()
         assert(!deDup.contains(key))
         ras.metadataModel.verify(metadata, 
ras.metadataModel.metadataOf(t.self()))
         deDup += key
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 65ff8b735..710a4e682 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.ras
 
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
 import org.apache.gluten.ras.property.PropertySet
 
 trait RasNode[T <: AnyRef] {
@@ -43,7 +43,7 @@ object RasNode {
       node.asInstanceOf[GroupNode[T]]
     }
 
-    def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self())
+    def toHashKey(): UnsafeHashKey[T] = node.ras().toHashKey(node.self())
   }
 }
 
@@ -131,16 +131,16 @@ object InGroupNode {
   private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: 
CanonicalNode[T])
     extends InGroupNode[T]
 
-  trait HashKey extends Any
+  trait UniqueKey extends Any
 
   implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) {
     import InGroupNodeImplicits._
-    def toHashKey: HashKey =
-      InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can))
+    def toUniqueKey: UniqueKey =
+      InGroupNodeUniqueKeyImpl(n.groupId, System.identityHashCode(n.can))
   }
 
   private object InGroupNodeImplicits {
-    private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends 
HashKey
+    private case class InGroupNodeUniqueKeyImpl(gid: Int, cid: Int) extends 
UniqueKey
   }
 }
 
@@ -159,15 +159,16 @@ object InClusterNode {
       can: CanonicalNode[T])
     extends InClusterNode[T]
 
-  trait HashKey extends Any
+  trait UniqueKey extends Any
 
   implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) {
     import InClusterNodeImplicits._
-    def toHashKey: HashKey =
-      InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can))
+    def toUniqueKey: UniqueKey =
+      InClusterNodeUniqueKeyImpl(n.clusterKey, System.identityHashCode(n.can))
   }
 
   private object InClusterNodeImplicits {
-    private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey, 
cid: Int) extends HashKey
+    private case class InClusterNodeUniqueKeyImpl(clusterKey: RasClusterKey, 
cid: Int)
+      extends UniqueKey
   }
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
index 74793a3d0..327b980f3 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
@@ -62,11 +62,11 @@ object Best {
       bestPath: KnownCostPath[T],
       winnerNodes: Seq[InGroupNode[T]],
       costs: InGroupNode[T] => Option[Cost]): Best[T] = {
-    val bestNodes = mutable.Set[InGroupNode.HashKey]()
+    val bestNodes = mutable.Set[InGroupNode.UniqueKey]()
 
     def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = {
       val can = cursor.self().asCanonical()
-      bestNodes += InGroupNode(groupId, can).toHashKey
+      bestNodes += InGroupNode(groupId, can).toUniqueKey
       cursor.zipChildrenWithGroupIds().foreach {
         case (childPathNode, childGroupId) =>
           dfs(childGroupId, childPathNode)
@@ -76,14 +76,14 @@ object Best {
     dfs(rootGroupId, bestPath.rasPath.node())
 
     val bestNodeSet = bestNodes.toSet
-    val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet
+    val winnerNodeSet = winnerNodes.map(_.toUniqueKey).toSet
 
     BestImpl(
       ras,
       rootGroupId,
       bestPath,
-      n => bestNodeSet.contains(n.toHashKey),
-      n => winnerNodeSet.contains(n.toHashKey),
+      n => bestNodeSet.contains(n.toUniqueKey),
+      n => winnerNodeSet.contains(n.toUniqueKey),
       costs)
   }
 
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
index 90a0adfb2..601cd72e5 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
@@ -57,17 +57,17 @@ object BestFinder {
 
     val bestPath = groupToCosts(group.id()).best()
     val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, 
g.bestNode) }.toSeq
-    val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
+    val costsMap = mutable.Map[InGroupNode.UniqueKey, Cost]()
     groupToCosts.foreach {
       case (gid, g) =>
         g.nodes.foreach {
           n =>
             val c = g.nodeToCost(n)
             if (c.nonEmpty) {
-              costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost)
+              costsMap += (InGroupNode(gid, n).toUniqueKey -> c.get.cost)
             }
         }
     }
-    Best(ras, group.id(), bestPath, winnerNodes, ign => 
costsMap.get(ign.toHashKey))
+    Best(ras, group.id(), bestPath, winnerNodes, ign => 
costsMap.get(ign.toUniqueKey))
   }
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
index 6fd95772b..046760ceb 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
@@ -73,7 +73,7 @@ object DpClusterAlgo {
       clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput])
     extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput, 
ClusterOutput] {
     override def idOfX(x: InClusterNode[T]): Any = {
-      x.toHashKey
+      x.toUniqueKey
     }
 
     override def idOfY(y: RasClusterKey): Any = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
index c824fda8e..f88f7b6e4 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
@@ -66,7 +66,7 @@ object DpGroupAlgo {
       groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput])
     extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput, 
GroupOutput] {
     override def idOfX(x: InGroupNode[T]): Any = {
-      x.toHashKey
+      x.toUniqueKey
     }
 
     override def idOfY(y: RasGroup[T]): Any = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 4a9e3f0f0..3f2590dff 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath
 import org.apache.gluten.ras.best.BestFinder
 import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel
 import org.apache.gluten.ras.memo.{Memo, MemoTable}
-import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath}
+import org.apache.gluten.ras.path._
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}
 
@@ -99,10 +99,16 @@ object DpPlanner {
       rules: Seq[RuleApplier[T]],
       enforcerRuleSet: EnforcerRuleSet[T])
     extends DpClusterAlgo.Adjustment[T] {
+    import ExploreAdjustment._
+
+    private val ruleShapes: Seq[Shape[T]] = rules.map(_.shape())
 
     override def exploreChildX(
         panel: Panel[InClusterNode[T], RasClusterKey],
-        x: InClusterNode[T]): Unit = {}
+        x: InClusterNode[T]): Unit = {
+      applyRulesOnNode(panel, x.clusterKey, x.can)
+    }
+
     override def exploreChildY(
         panel: Panel[InClusterNode[T], RasClusterKey],
         y: RasClusterKey): Unit = {}
@@ -115,20 +121,24 @@ object DpPlanner {
         cKey: RasClusterKey): Unit = {
       memoTable.doExhaustively {
         applyEnforcerRules(panel, cKey)
-        applyRules(panel, cKey)
       }
     }
 
-    private def applyRules(
+    private def applyRulesOnNode(
         panel: Panel[InClusterNode[T], RasClusterKey],
-        cKey: RasClusterKey): Unit = {
+        cKey: RasClusterKey,
+        can: CanonicalNode[T]): Unit = {
       if (rules.isEmpty) {
         return
       }
       val dummyGroup = memoTable.getDummyGroup(cKey)
-      val shapes = rules.map(_.shape())
-      findPaths(GroupNode(ras, dummyGroup), shapes) {
-        path => rules.foreach(rule => applyRule(panel, cKey, rule, path))
+      findPaths(GroupNode(ras, dummyGroup), ruleShapes, List(new 
FromSingleNode[T](can))) {
+        path =>
+          val rootNode = path.node().self()
+          if (rootNode.isCanonical) {
+            assert(rootNode.asCanonical() eq can)
+          }
+          rules.foreach(rule => applyRule(panel, cKey, rule, path))
       }
     }
 
@@ -137,27 +147,34 @@ object DpPlanner {
         cKey: RasClusterKey): Unit = {
       val dummyGroup = memoTable.getDummyGroup(cKey)
       cKey.propSets(memoTable).foreach {
-        constraintSet =>
+        constraintSet: PropertySet[T] =>
           val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
           if (enforcerRules.nonEmpty) {
-            val shapes = enforcerRules.map(_.shape())
-            findPaths(GroupNode(ras, dummyGroup), shapes) {
+            val shapes = enforcerRuleSet.ruleShapesOf(constraintSet)
+            findPaths(GroupNode(ras, dummyGroup), shapes, List.empty) {
               path => enforcerRules.foreach(rule => applyRule(panel, cKey, 
rule, path))
             }
           }
       }
     }
 
-    private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
+    private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]], filters: 
Seq[FilterWizard[T]])(
         onFound: RasPath[T] => Unit): Unit = {
-      val finder = shapes
+      val finderBuilder = shapes
         .foldLeft(
           PathFinder
             .builder(ras, memoTable)) {
           case (builder, shape) =>
             builder.output(shape.wizard())
         }
+
+      val finder = filters
+        .foldLeft(finderBuilder) {
+          case (builder, filter) =>
+            builder.filter(filter)
+        }
         .build()
+
       finder.find(gn).foreach(path => onFound(path))
     }
 
@@ -191,5 +208,22 @@ object DpPlanner {
     }
   }
 
-  private object ExploreAdjustment {}
+  private object ExploreAdjustment {
+    private class FromSingleNode[T <: AnyRef](from: CanonicalNode[T]) extends 
FilterWizard[T] {
+      override def omit(can: CanonicalNode[T]): FilterWizard.FilterAction[T] = 
{
+        if (can eq from) {
+          return FilterWizard.FilterAction.Continue(this)
+        }
+        FilterWizard.FilterAction.omit
+      }
+
+      override def omit(group: GroupNode[T]): FilterWizard.FilterAction[T] =
+        FilterWizard.FilterAction.Continue(this)
+
+      override def advance(offset: Int, count: Int): 
FilterWizard.FilterAdvanceAction[T] = {
+        // We only filter on nodes from the root group. So continue with a 
noop filter.
+        FilterWizard.FilterAdvanceAction.Continue(FilterWizards.none())
+      }
+    }
+  }
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index 6406b8fb1..c67120357 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.ras.memo
 
 import org.apache.gluten.ras._
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.ras.vis.GraphvizVisualizer
 
@@ -236,11 +236,11 @@ object Memo {
   private object MemoCacheKey {
     def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = {
       assert(ras.isCanonical(self))
-      MemoCacheKey[T](ras.toUnsafeKey(self))
+      MemoCacheKey[T](ras.toHashKey(self))
     }
   }
 
-  private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T])
+  private case class MemoCacheKey[T <: AnyRef] private (delegate: 
UnsafeHashKey[T])
 }
 
 trait MemoStore[T <: AnyRef] {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
index 126ae7766..253e9ec84 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
@@ -54,6 +54,21 @@ object FilterWizards {
     OmitCycles[T](CycleDetector[GroupNode[T]]((one, other) => one.groupId() == 
other.groupId()))
   }
 
+  def none[T <: AnyRef](): FilterWizard[T] = {
+    None[T]()
+  }
+
+  private class None[T <: AnyRef] private () extends FilterWizard[T] {
+    override def omit(can: CanonicalNode[T]): FilterAction[T] = 
FilterAction.Continue(this)
+    override def omit(group: GroupNode[T]): FilterAction[T] = 
FilterAction.Continue(this)
+    override def advance(offset: Int, count: Int): FilterAdvanceAction[T] =
+      FilterAdvanceAction.Continue(this)
+  }
+
+  private object None {
+    def apply[T <: AnyRef](): None[T] = new None[T]()
+  }
+
   // Cycle detection starts from the first visited group in the input path.
   private class OmitCycles[T <: AnyRef] private (detector: 
CycleDetector[GroupNode[T]])
     extends FilterWizard[T] {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
index c18936973..439b88a2c 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
@@ -54,6 +54,7 @@ object EnforcerRule {
 
 trait EnforcerRuleSet[T <: AnyRef] {
   def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]]
+  def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]]
 }
 
 object EnforcerRuleSet {
@@ -73,21 +74,31 @@ object EnforcerRuleSet {
       mutable.Map[PropertyDef[T, _ <: Property[T]], EnforcerRuleFactory[T]]()
     private val buffer = mutable.Map[Property[T], Seq[RuleApplier[T]]]()
 
+    private val rulesBuffer = mutable.Map[PropertySet[T], 
Seq[RuleApplier[T]]]()
+    private val shapesBuffer = mutable.Map[PropertySet[T], Seq[Shape[T]]]()
+
     override def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]] = 
{
-      constraintSet.getMap.flatMap {
-        case (constraintDef, constraint) =>
-          buffer.getOrElseUpdate(
-            constraint, {
-              val factory =
-                factoryBuffer.getOrElseUpdate(
-                  constraintDef,
-                  newEnforcerRuleFactory(ras, constraintDef))
-              RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +: 
factory
-                .newEnforcerRules(constraint)
-                .map(rule => RuleApplier(ras, closure, EnforcerRule(rule, 
constraint)))
-            }
-          )
-      }.toSeq
+      rulesBuffer.getOrElseUpdate(
+        constraintSet,
+        constraintSet.getMap.flatMap {
+          case (constraintDef, constraint) =>
+            buffer.getOrElseUpdate(
+              constraint, {
+                val factory =
+                  factoryBuffer.getOrElseUpdate(
+                    constraintDef,
+                    newEnforcerRuleFactory(ras, constraintDef))
+                RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +: 
factory
+                  .newEnforcerRules(constraint)
+                  .map(rule => RuleApplier(ras, closure, EnforcerRule(rule, 
constraint)))
+              }
+            )
+        }.toSeq
+      )
+    }
+
+    override def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]] = {
+      shapesBuffer.getOrElseUpdate(constraintSet, 
rulesBuffer(constraintSet).map(_.shape()))
     }
   }
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index 01e826f06..6b4082c7e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -17,7 +17,7 @@
 package org.apache.gluten.ras.rule
 
 import org.apache.gluten.ras._
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
 import org.apache.gluten.ras.memo.Closure
 import org.apache.gluten.ras.path.InClusterPath
 import org.apache.gluten.ras.property.PropertySet
@@ -43,14 +43,14 @@ object RuleApplier {
 
   private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure: 
Closure[T], rule: RasRule[T])
     extends RuleApplier[T] {
-    private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
+    private val deDup = mutable.Map[RasClusterKey, 
mutable.Set[UnsafeHashKey[T]]]()
 
     override def apply(icp: InClusterPath[T]): Unit = {
       val cKey = icp.cluster()
       val path = icp.path()
       val plan = path.plan()
       val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
-      val pKey = ras.toUnsafeKey(plan)
+      val pKey = ras.toHashKey(plan)
       if (appliedPlans.contains(pKey)) {
         return
       }
@@ -76,7 +76,7 @@ object RuleApplier {
       closure: Closure[T],
       rule: EnforcerRule[T])
     extends RuleApplier[T] {
-    private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
+    private val deDup = mutable.Map[RasClusterKey, 
mutable.Set[UnsafeHashKey[T]]]()
     private val constraint = rule.constraint()
     private val constraintDef = constraint.definition()
 
@@ -88,7 +88,7 @@ object RuleApplier {
         return
       }
       val plan = path.plan()
-      val pKey = ras.toUnsafeKey(plan)
+      val pKey = ras.toHashKey(plan)
       val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
       if (appliedPlans.contains(pKey)) {
         return
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 018c8087e..b420d8c29 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -29,7 +29,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
   private val allGroups = memoState.allGroups()
   private val allClusters = memoState.clusterLookup()
 
-  private val nodeToId = mutable.Map[InGroupNode.HashKey, Int]()
+  private val nodeToId = mutable.Map[InGroupNode.UniqueKey, Int]()
 
   def format(): String = {
     val rootGroupId = best.rootGroupId()
@@ -156,7 +156,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
       group: RasGroup[T],
       node: CanonicalNode[T]): String = {
     val ign = InGroupNode(group.id(), node)
-    val nodeId = nodeToId.getOrElseUpdate(ign.toHashKey, nodeToId.size)
+    val nodeId = nodeToId.getOrElseUpdate(ign.toUniqueKey, nodeToId.size)
     s"[$nodeId][Cost ${costs(ign)
         .map {
           case c if ras.isInfCost(c) => "<INF>"
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index b29e0c267..2f3ef348c 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.ras.RasConfig.PlannerType
 import org.apache.gluten.ras.RasSuiteBase._
 import org.apache.gluten.ras.memo.Memo
 import org.apache.gluten.ras.path.Pattern
+import org.apache.gluten.ras.path.Pattern.Matchers
 import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
 
 import org.scalatest.funsuite.AnyFunSuite
@@ -265,6 +266,58 @@ abstract class RasSuite extends AnyFunSuite {
     assert(allPaths.size == 15)
   }
 
+  test(s"Rule dependency") {
+    // Op3 relies on Op2 relies on Op1
+
+    object Op1 extends RasRule[TestNode] {
+      override def shift(node: TestNode): Iterable[TestNode] = node match {
+        case Leaf(70) =>
+          List(Leaf(69))
+        case other => List.empty
+      }
+
+      override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+    }
+
+    object Op2 extends RasRule[TestNode] {
+      override def shift(node: TestNode): Iterable[TestNode] = node match {
+        case Leaf(69) =>
+          List(Leaf(68))
+        case other => List.empty
+      }
+
+      override def shape(): Shape[TestNode] =
+        
Shapes.pattern(Pattern.leaf[TestNode](Matchers.clazz(classOf[Leaf])).build())
+    }
+
+    object Op3 extends RasRule[TestNode] {
+      override def shift(node: TestNode): Iterable[TestNode] = node match {
+        case Leaf(68) =>
+          List(Leaf(67))
+        case other => List.empty
+      }
+
+      override def shape(): Shape[TestNode] =
+        Shapes.pattern(Pattern.any[TestNode].build())
+    }
+
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.reuse(List(Op3, Op1, Op2)))
+        .withNewConfig(_ => conf)
+
+    val plan = Unary(90, Unary(90, Leaf(70)))
+    val planner = ras.newPlanner(plan)
+    val optimized = planner.plan()
+
+    assert(optimized == Unary(90, Unary(90, Leaf(67))))
+  }
+
   test(s"Unary node insertion") {
     object InsertUnary2 extends RasRule[TestNode] {
       override def shift(node: TestNode): Iterable[TestNode] = node match {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to