This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new c50e21ddd0 [VL] RAS: Various fixes (#9803)
c50e21ddd0 is described below

commit c50e21ddd032e494896a49f0832f3ac7c4c471b5
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu May 29 18:34:30 2025 +0100

    [VL] RAS: Various fixes (#9803)
---
 .../enumerated/planner/VeloxRasSuite.scala         |   4 +-
 .../src/main/scala/org/apache/gluten/ras/Ras.scala |  46 +++---
 .../scala/org/apache/gluten/ras/RasCluster.scala   |  12 +-
 .../org/apache/gluten/ras/dp/DpGroupAlgo.scala     |   4 +-
 .../scala/org/apache/gluten/ras/dp/DpPlanner.scala |  22 +--
 .../gluten/ras/exaustive/ExhaustivePlanner.scala   |  21 +--
 .../apache/gluten/ras/memo/ForwardMemoTable.scala  |  40 +++--
 .../scala/org/apache/gluten/ras/memo/Memo.scala    |  12 +-
 .../org/apache/gluten/ras/memo/MemoTable.scala     |  17 +--
 .../org/apache/gluten/ras/property/MemoRole.scala  | 162 ++++++++++++---------
 .../apache/gluten/ras/rule/EnforcerRuleSet.scala   |   2 +-
 .../apache/gluten/ras/vis/GraphvizVisualizer.scala |  26 +++-
 .../org/apache/gluten/ras/mock/MockMemoState.scala |   7 +-
 .../apache/gluten/ras/property/MemoRoleSuite.scala |  40 +++++
 14 files changed, 230 insertions(+), 185 deletions(-)

diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
index 987e8ce70b..8db5c0b1b9 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
@@ -81,7 +81,7 @@ class VeloxRasSuite extends SharedSparkSession {
     val numGroups = memoState.allGroups().size
     val numNodes = memoState.allClusters().flatMap(_.nodes()).size
     assert(numClusters == 8)
-    assert(numGroups == 30)
+    assert(numGroups == 22)
     assert(numNodes == 39)
   }
 
@@ -110,7 +110,7 @@ class VeloxRasSuite extends SharedSparkSession {
     val numGroups = memoState.allGroups().size
     val numNodes = memoState.allClusters().flatMap(_.nodes()).size
     assert(numClusters == 8)
-    assert(numGroups == 32)
+    assert(numGroups == 28)
     assert(numNodes == 55)
   }
 
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 785afe5ebc..7dc2c2d42c 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -17,11 +17,12 @@
 package org.apache.gluten.ras
 
 import org.apache.gluten.ras.property.{MemoRole, PropertySet, 
PropertySetFactory}
-import org.apache.gluten.ras.rule.RasRule
+import org.apache.gluten.ras.property.MemoRole.PropertySetFactoryWithMemoRole
+import org.apache.gluten.ras.rule.{EnforcerRuleFactory, RasRule}
 
 /**
- * Entrypoint of RAS (relational algebra selector)'s search engine. See basic 
introduction of RAS:
- * https://github.com/apache/incubator-gluten/issues/5057.
+ * Entrypoint of RAS (relational algebra selector)'s search engine. See the 
basic introduction of
+ * RAS: https://github.com/apache/incubator-gluten/issues/5057.
  */
 trait Optimization[T <: AnyRef] {
   def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T]
@@ -50,11 +51,11 @@ class Ras[T <: AnyRef] private (
   extends Optimization[T] {
   import Ras._
 
-  private[ras] val memoRoleDef: MemoRole.Def[T] = MemoRole.newDef(planModel)
-  private val userPropertySetFactory: PropertySetFactory[T] =
-    PropertySetFactory(propertyModel, planModel)
-  private val propSetFactory: PropertySetFactory[T] =
-    MemoRole.wrapPropertySetFactory(userPropertySetFactory, memoRoleDef)
+  private val propSetFactory: PropertySetFactoryWithMemoRole[T] = {
+    val memoRoleDef: MemoRole.Def[T] = MemoRole.newDef(planModel)
+    val baseFactory = PropertySetFactory(propertyModel, planModel)
+    MemoRole.wrapPropertySetFactory(baseFactory, memoRoleDef)
+  }
   // Normal groups start with ID 0, so it's safe to use Int.MinValue to do 
validation.
   private val dummyGroup: T =
     newGroupLeaf(Int.MinValue, metadataModel.dummy(), propSetFactory.any())
@@ -91,11 +92,11 @@ class Ras[T <: AnyRef] private (
   }
 
   override def newPlanner(plan: T, constraintSet: PropertySet[T]): 
RasPlanner[T] = {
-    RasPlanner(this, constraintSet, plan)
+    RasPlanner(this, withUserConstraint(constraintSet), plan)
   }
 
   def newPlanner(plan: T): RasPlanner[T] = {
-    RasPlanner(this, userPropertySetFactory.any(), plan)
+    RasPlanner(this, userConstraintSet(), plan)
   }
 
   def withNewConfig(confFunc: RasConfig => RasConfig): Ras[T] = {
@@ -109,15 +110,26 @@ class Ras[T <: AnyRef] private (
       ruleFactory)
   }
 
-  private[ras] def userConstraintSet(): PropertySet[T] =
-    userPropertySetFactory.any() +: memoRoleDef.reqUser
+  private[ras] def withUserConstraint(from: PropertySet[T]): PropertySet[T] = {
+    from +: propSetFactory.userConstraint()
+  }
+
+  private[ras] def userConstraintSet(): PropertySet[T] = 
propSetFactory.userConstraintSet()
 
-  private[ras] def hubConstraintSet(): PropertySet[T] =
-    userPropertySetFactory.any() +: memoRoleDef.reqHub
+  private[ras] def hubConstraintSet(): PropertySet[T] = 
propSetFactory.hubConstraintSet()
 
   private[ras] def propSetOf(plan: T): PropertySet[T] = {
-    val out = propertySetFactory().get(plan)
-    out
+    propSetFactory.get(plan)
+  }
+
+  private[ras] def childrenConstraintSets(
+      node: T,
+      constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
+    propSetFactory.childrenConstraintSets(node, constraintSet)
+  }
+
+  private[ras] def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
+    propSetFactory.newEnforcerRuleFactory()
   }
 
   private[ras] def withNewChildren(node: T, newChildren: Seq[T]): T = {
@@ -148,8 +160,6 @@ class Ras[T <: AnyRef] private (
       .map(child => planModel.getGroupId(child))
   }
 
-  private[ras] def propertySetFactory(): PropertySetFactory[T] = propSetFactory
-
   private[ras] def dummyGroupLeaf(): T = {
     dummyGroup
   }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index 98f03eb961..c186752cda 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -77,14 +77,12 @@ object RasCluster {
     }
   }
 
-  case class ImmutableRasCluster[T <: AnyRef] private (
-      ras: Ras[T],
-      override val nodes: Seq[CanonicalNode[T]])
-    extends RasCluster[T]
-
   object ImmutableRasCluster {
-    def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]): 
ImmutableRasCluster[T] = {
-      ImmutableRasCluster(ras, cluster.nodes().toVector)
+    def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]): RasCluster[T] 
= {
+      new Impl[T](ras, cluster.nodes().toSeq)
     }
+
+    private class Impl[T <: AnyRef](ras: Ras[T], override val nodes: 
Seq[CanonicalNode[T]])
+      extends RasCluster[T]
   }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
index 13e103cfce..9172814354 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
@@ -20,8 +20,8 @@ import org.apache.gluten.ras.{InGroupNode, RasGroup}
 import org.apache.gluten.ras.dp.DpZipperAlgo.Solution
 import org.apache.gluten.ras.memo.MemoState
 
-// Dynamic programming algorithm to solve problem against a single RAS group 
that can be
-// broken down to sub problems for subgroups.
+// Dynamic programming algorithm to solve a problem against a single RAS group 
that can be
+// broken down to subproblems for subgroups.
 trait DpGroupAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef] 
{
   def solveNode(node: InGroupNode[T], childrenGroupsOutput: RasGroup[T] => 
GroupOutput): NodeOutput
   def solveGroup(group: RasGroup[T], nodesOutput: InGroupNode[T] => 
NodeOutput): GroupOutput
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 0aca6bf13a..b2d429ed0e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -36,7 +36,7 @@ private class DpPlanner[T <: AnyRef] private (ras: Ras[T], 
constraintSet: Proper
   private val deriverRuleSetFactory = EnforcerRuleSet.Factory.derive(ras, memo)
 
   private lazy val rootGroupId: Int = {
-    memo.memorize(plan, constraintSet +: ras.memoRoleDef.reqUser).id()
+    memo.memorize(plan, constraintSet).id()
   }
 
   private lazy val best: (Best[T], KnownCostPath[T]) = {
@@ -100,7 +100,6 @@ object DpPlanner {
     override def exploreChildX(
         panel: Panel[InClusterNode[T], RasClusterKey],
         x: InClusterNode[T]): Unit = {
-      applyHubRulesOnUserNode(panel, x.clusterKey, x.can)
       applyRulesOnHubNode(panel, x.clusterKey, x.can)
     }
 
@@ -136,25 +135,6 @@ object DpPlanner {
       }
     }
 
-    private def applyHubRulesOnUserNode(
-        panel: Panel[InClusterNode[T], RasClusterKey],
-        cKey: RasClusterKey,
-        can: CanonicalNode[T]): Unit = {
-      val hubConstraint = ras.hubConstraintSet()
-      val hubDeriverRuleSet = deriverRuleSetFactory.ruleSetOf(hubConstraint)
-      val hubDeriverRules = hubDeriverRuleSet.rules()
-      if (hubDeriverRules.nonEmpty) {
-        val hubDeriverRuleShapes = hubDeriverRuleSet.shapes()
-        val userGroup = memoTable.getUserGroup(cKey)
-        findPaths(
-          GroupNode(ras, userGroup),
-          hubDeriverRuleShapes,
-          List(new FromSingleNode[T](can))) {
-          path => hubDeriverRules.foreach(rule => applyRule(panel, cKey, rule, 
path))
-        }
-      }
-    }
-
     private def applyEnforcerRules(
         panel: Panel[InClusterNode[T], RasClusterKey],
         cKey: RasClusterKey): Unit = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index 58a37afa47..dca811134b 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -36,7 +36,7 @@ private class ExhaustivePlanner[T <: AnyRef] private (
   private val deriverRuleSetFactory = EnforcerRuleSet.Factory.derive(ras, memo)
 
   private lazy val rootGroupId: Int = {
-    memo.memorize(plan, constraintSet +: ras.memoRoleDef.reqUser).id()
+    memo.memorize(plan, constraintSet).id()
   }
 
   private lazy val best: (Best[T], KnownCostPath[T]) = {
@@ -91,7 +91,6 @@ object ExhaustivePlanner {
 
     def explore(): Unit = {
       // TODO: ONLY APPLY RULES ON ALTERED GROUPS (and close parents)
-      applyHubRules()
       applyEnforcerRules()
       applyRules()
     }
@@ -129,24 +128,6 @@ object ExhaustivePlanner {
         }
     }
 
-    private def applyHubRules(): Unit = {
-      val hubConstraint = ras.hubConstraintSet()
-      val hubDeriverRuleSet = deriverRuleSetFactory.ruleSetOf(hubConstraint)
-      val hubDeriverRules = hubDeriverRuleSet.rules()
-      if (hubDeriverRules.nonEmpty) {
-        memoState
-          .clusterLookup()
-          .foreach {
-            case (cKey, cluster) =>
-              val hubDeriverRuleShapes = hubDeriverRuleSet.shapes()
-              val userGroup = memoState.getUserGroup(cKey)
-              findPaths(GroupNode(ras, userGroup), hubDeriverRuleShapes) {
-                path => hubDeriverRules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path)))
-              }
-          }
-      }
-    }
-
     private def applyEnforcerRules(): Unit = {
       allGroups.foreach {
         group =>
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index cbd43026c6..791e848c2f 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -48,13 +48,12 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
   override def newCluster(metadata: Metadata): RasClusterKey = {
     checkBufferSizes()
     val clusterId = clusterBuffer.size
-    val key = IntClusterKey(clusterId, metadata)
+    val key = IntClusterKey(clusterId)(metadata)
     clusterKeyBuffer += key
     clusterBuffer += MutableRasCluster(ras, metadata)
     clusterDisjointSet.grow()
     groupLookup += mutable.Map()
     groupOf(key, ras.hubConstraintSet())
-    groupOf(key, ras.userConstraintSet())
     memoWriteCount += 1
     key
   }
@@ -67,7 +66,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
     }
     val gid = groupBuffer.size
     val newGroup =
-      RasGroup(ras, IntClusterKey(ancestor, key.metadata), gid, constraintSet)
+      RasGroup(ras, IntClusterKey(ancestor)(key.metadata), gid, constraintSet)
     lookup += constraintSet -> newGroup
     groupBuffer += newGroup
     memoWriteCount += 1
@@ -80,12 +79,24 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
   }
 
   override def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit 
= {
+    if (addToCluster0(key, node)) {
+      // Insert the corresponding hub node right away.
+      addToCluster0(key, node.toHubNode(this))
+      return
+    }
+    // Node was already inserted into the cluster.
+    // Do an assertion to ensure the corresponding hub node was inserted as 
well.
+    assert(!addToCluster0(key, node.toHubNode(this)))
+  }
+
+  private def addToCluster0(key: RasClusterKey, node: CanonicalNode[T]): 
Boolean = {
     val cluster = getCluster(key)
     if (cluster.contains(node)) {
-      return
+      return false
     }
     cluster.add(node)
     memoWriteCount += 1
+    true
   }
 
   override def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit = 
{
@@ -169,18 +180,12 @@ class ForwardMemoTable[T <: AnyRef] private (override val 
ras: Ras[T])
     val lookup = groupLookup(ancestor)
     lookup(ras.hubConstraintSet())
   }
-
-  override def getUserGroup(key: RasClusterKey): RasGroup[T] = {
-    val ancestor = ancestorClusterIdOf(key)
-    val lookup = groupLookup(ancestor)
-    lookup(ras.userConstraintSet())
-  }
 }
 
 object ForwardMemoTable {
   def apply[T <: AnyRef](ras: Ras[T]): MemoTable.Writable[T] = new 
ForwardMemoTable[T](ras)
 
-  private case class IntClusterKey(id: Int, metadata: Metadata) extends 
RasClusterKey
+  private case class IntClusterKey(id: Int)(override val metadata: Metadata) 
extends RasClusterKey
 
   private class Probe[T <: AnyRef](table: ForwardMemoTable[T]) extends 
MemoTable.Probe[T] {
     private val probedClusterCount: Int = table.clusterKeyBuffer.size
@@ -228,4 +233,17 @@ object ForwardMemoTable {
       key.asInstanceOf[IntClusterKey]
     }
   }
+
+  implicit class CanonicalNodeImplicits[T <: AnyRef](node: CanonicalNode[T]) {
+    def toHubNode(store: MemoStore[T]): CanonicalNode[T] = {
+      val ras = node.ras()
+      val canUnsafe = ras.withNewChildren(
+        node.self(),
+        ras
+          .getChildrenGroupIds(node.self())
+          .map(gid => store.asGroupSupplier()(gid).clusterKey())
+          .map(cKey => store.getHubGroup(cKey).self()))
+      CanonicalNode(ras, canUnsafe)
+    }
+  }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index db3e90d7ab..1286824723 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -66,7 +66,7 @@ object Memo {
       if (cache.contains(cacheKey)) {
         cache(cacheKey)
       } else {
-        // Node not yet added to cluster.
+        // Node was not yet added to a cluster.
         val cluster = newCluster(metadata)
         cache += (cacheKey -> cluster)
         cluster
@@ -135,7 +135,7 @@ object Memo {
           if (residentCluster == targetCluster) {
             return Prepare.cluster(parent, targetCluster)
           }
-          // The resident cluster of group leaf is not the same with target 
cluster.
+          // The resident cluster of group leaf is different from the target 
cluster.
           // Merge.
           parent.memoTable.mergeClusters(residentCluster, targetCluster)
           return Prepare.cluster(parent, targetCluster)
@@ -153,7 +153,7 @@ object Memo {
         val cacheKey = parent.toCacheKey(keyUnsafe)
 
         if (!parent.cache.contains(cacheKey)) {
-          // The new node was not added to memo yet. Add it to the target 
cluster.
+          // The new node was not added to the memo yet. Add it to the target 
cluster.
           parent.cache += (cacheKey -> targetCluster)
           return Prepare.tree(parent, targetCluster, childrenPrepares)
         }
@@ -164,7 +164,7 @@ object Memo {
           // The new node already memorized to memo and in the target cluster.
           return Prepare.tree(parent, targetCluster, childrenPrepares)
         }
-        // The new node already memorized to memo, but in the different 
cluster.
+        // The new node already memorized to memo, but in a different cluster.
         // Merge the two clusters.
         parent.memoTable.mergeClusters(cachedCluster, targetCluster)
         Prepare.tree(parent, targetCluster, childrenPrepares)
@@ -204,7 +204,7 @@ object Memo {
           assert(!ras.isGroupLeaf(node))
           val childrenGroups = children
             .zip(ras.planModel.childrenOf(node))
-            .zip(ras.propertySetFactory().childrenConstraintSets(node, 
constraintSet))
+            .zip(ras.childrenConstraintSets(node, constraintSet))
             .map {
               case ((childPrepare, child), childConstraintSet) =>
                 childPrepare.doInsert(child, childConstraintSet)
@@ -250,7 +250,6 @@ object Memo {
 trait MemoStore[T <: AnyRef] {
   def getCluster(key: RasClusterKey): RasCluster[T]
   def getHubGroup(key: RasClusterKey): RasGroup[T]
-  def getUserGroup(key: RasClusterKey): RasGroup[T]
   def getGroup(id: Int): RasGroup[T]
 }
 
@@ -266,7 +265,6 @@ trait MemoState[T <: AnyRef] extends MemoStore[T] {
   def ras(): Ras[T]
   def clusterLookup(): Map[RasClusterKey, RasCluster[T]]
   def clusterHubGroupLookup(): Map[RasClusterKey, RasGroup[T]]
-  def clusterUserGroupLookup(): Map[RasClusterKey, RasGroup[T]]
   def allClusters(): Iterable[RasCluster[T]]
   def allGroups(): Seq[RasGroup[T]]
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 3bdf7b794e..5b1b8d04d6 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -66,9 +66,8 @@ object MemoTable {
 
   private case class MemoStateImpl[T <: AnyRef](
       override val ras: Ras[T],
-      override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
+      override val clusterLookup: Map[RasClusterKey, RasCluster[T]],
       override val clusterHubGroupLookup: Map[RasClusterKey, RasGroup[T]],
-      override val clusterUserGroupLookup: Map[RasClusterKey, RasGroup[T]],
       override val allGroups: Seq[RasGroup[T]],
       idToGroup: Map[Int, RasGroup[T]])
     extends MemoState[T] {
@@ -76,7 +75,6 @@ object MemoTable {
 
     override def getCluster(key: RasClusterKey): RasCluster[T] = 
clusterLookup(key)
     override def getHubGroup(key: RasClusterKey): RasGroup[T] = 
clusterHubGroupLookup(key)
-    override def getUserGroup(key: RasClusterKey): RasGroup[T] = 
clusterUserGroupLookup(key)
     override def getGroup(id: Int): RasGroup[T] = idToGroup(id)
     override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
   }
@@ -93,11 +91,6 @@ object MemoTable {
         .map(key => key -> table.getHubGroup(key))
         .toMap
 
-      val immutableUserGroups = table
-        .allClusterKeys()
-        .map(key => key -> table.getUserGroup(key))
-        .toMap
-
       var maxGroupId = Int.MinValue
 
       val groupMap = table
@@ -115,13 +108,7 @@ object MemoTable {
 
       val allGroups = (0 to maxGroupId).map(table.getGroup).toVector
 
-      MemoStateImpl(
-        table.ras,
-        immutableClusters,
-        immutableHubGroups,
-        immutableUserGroups,
-        allGroups,
-        groupMap)
+      MemoStateImpl(table.ras, immutableClusters, immutableHubGroups, 
allGroups, groupMap)
     }
 
     def doExhaustively(func: => Unit): Unit = {
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
index 523597eddb..62e5936d12 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
@@ -27,8 +27,8 @@ sealed trait MemoRole[T <: AnyRef] extends Property[T] {
 
 object MemoRole {
   implicit class MemoRoleImplicits[T <: AnyRef](role: MemoRole[T]) {
-    def asReq(): Req[T] = role.asInstanceOf[Req[T]]
-    def asProp(): Prop[T] = role.asInstanceOf[Prop[T]]
+    private[MemoRole] def asReq(): Req[T] = role.asInstanceOf[Req[T]]
+    private[MemoRole] def asProp(): Prop[T] = role.asInstanceOf[Prop[T]]
 
     def +:(base: PropertySet[T]): PropertySet[T] = {
       require(!base.asMap.contains(role.definition()))
@@ -39,14 +39,14 @@ object MemoRole {
     }
   }
 
-  trait Req[T <: AnyRef] extends MemoRole[T]
-  trait Prop[T <: AnyRef] extends MemoRole[T]
+  sealed private trait Req[T <: AnyRef] extends MemoRole[T]
+  sealed private trait Prop[T <: AnyRef] extends MemoRole[T]
 
   // Constraints.
-  class ReqHub[T <: AnyRef] private[MemoRole] (
+  private class ReqHub[T <: AnyRef] private[MemoRole] (
       override val definition: PropertyDef[T, _ <: Property[T]])
     extends Req[T]
-  class ReqUser[T <: AnyRef] private[MemoRole] (
+  private class ReqUser[T <: AnyRef] private[MemoRole] (
       override val definition: PropertyDef[T, _ <: Property[T]])
     extends Req[T]
   private class ReqAny[T <: AnyRef] private[MemoRole] (
@@ -54,13 +54,13 @@ object MemoRole {
     extends Req[T]
 
   // Props.
-  class Leaf[T <: AnyRef] private[MemoRole] (
+  private class Leaf[T <: AnyRef] private[MemoRole] (
       override val definition: PropertyDef[T, _ <: Property[T]])
     extends Prop[T]
-  class Hub[T <: AnyRef] private[MemoRole] (
+  private class Hub[T <: AnyRef] private[MemoRole] (
       override val definition: PropertyDef[T, _ <: Property[T]])
     extends Prop[T]
-  class User[T <: AnyRef] private[MemoRole] (
+  private class User[T <: AnyRef] private[MemoRole] (
       override val definition: PropertyDef[T, _ <: Property[T]])
     extends Prop[T]
 
@@ -68,13 +68,13 @@ object MemoRole {
     extends PropertyDef[T, MemoRole[T]] {
     private val groupRoleLookup = mutable.Map[Int, Prop[T]]()
 
+    private[MemoRole] val reqHub = new ReqHub[T](this)
+    private[MemoRole] val reqUser = new ReqUser[T](this)
     private val reqAny = new ReqAny[T](this)
-    val reqHub = new ReqHub[T](this)
-    val reqUser = new ReqUser[T](this)
 
-    val leaf = new Leaf[T](this)
-    val hub = new Hub[T](this)
-    val user = new User[T](this)
+    private val leaf = new Leaf[T](this)
+    private val hub = new Hub[T](this)
+    private val user = new User[T](this)
 
     override def any(): MemoRole[T] = reqAny
 
@@ -127,7 +127,7 @@ object MemoRole {
     }
   }
 
-  implicit class DefImplicits[T <: AnyRef](roleDef: Def[T]) {
+  implicit private class DefImplicits[T <: AnyRef](roleDef: Def[T]) {
     def -:(base: PropertySet[T]): PropertySet[T] = {
       require(base.asMap.contains(roleDef))
       val map: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = {
@@ -137,82 +137,104 @@ object MemoRole {
     }
   }
 
-  def newDef[T <: AnyRef](planModel: PlanModel[T]): Def[T] = {
+  private[ras] def newDef[T <: AnyRef](planModel: PlanModel[T]): Def[T] = {
     new Def[T](planModel)
   }
 
-  def wrapPropertySetFactory[T <: AnyRef](
+  private[ras] def wrapPropertySetFactory[T <: AnyRef](
       factory: PropertySetFactory[T],
-      roleDef: Def[T]): PropertySetFactory[T] = {
-    new PropertySetFactoryWithMemoRole[T](factory, roleDef)
+      roleDef: Def[T]): PropertySetFactoryWithMemoRole[T] = {
+    PropertySetFactoryWithMemoRole(factory, roleDef)
   }
 
-  private class PropertySetFactoryWithMemoRole[T <: AnyRef](
-      delegate: PropertySetFactory[T],
-      roleDef: Def[T])
-    extends PropertySetFactory[T] {
+  trait PropertySetFactoryWithMemoRole[T <: AnyRef] extends 
PropertySetFactory[T] {
+    def userConstraint(): MemoRole[T]
+    def userConstraintSet(): PropertySet[T]
+    def hubConstraintSet(): PropertySet[T]
+  }
 
-    override val any: PropertySet[T] = compose(roleDef.any(), delegate.any())
+  private object PropertySetFactoryWithMemoRole {
+    def apply[T <: AnyRef](
+        factory: PropertySetFactory[T],
+        roleDef: Def[T]): PropertySetFactoryWithMemoRole[T] = {
+      new Impl(factory, roleDef)
+    }
 
-    override def get(node: T): PropertySet[T] =
-      compose(roleDef.getProperty(node), delegate.get(node))
+    private class Impl[T <: AnyRef](delegate: PropertySetFactory[T], roleDef: 
Def[T])
+      extends PropertySetFactoryWithMemoRole[T] {
 
-    override def childrenConstraintSets(
-        node: T,
-        constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
-      assert(!roleDef.planModel.isGroupLeaf(node))
+      override val any: PropertySet[T] = compose(roleDef.any(), delegate.any())
 
-      if (roleDef.planModel.isLeaf(node)) {
-        return Nil
-      }
+      override val userConstraint: MemoRole[T] = roleDef.reqUser
 
-      val numChildren = roleDef.planModel.childrenOf(node).size
+      override val userConstraintSet: PropertySet[T] =
+        delegate.any() +: roleDef.reqUser
 
-      def delegateChildrenConstraintSets(): Seq[PropertySet[T]] = {
-        val roleRemoved = PropertySet(constraintSet.asMap - roleDef)
-        val out = delegate.childrenConstraintSets(node, roleRemoved)
-        out
-      }
+      override val hubConstraintSet: PropertySet[T] =
+        delegate.any() +: roleDef.reqHub
 
-      def delegateConstraintSetAny(): PropertySet[T] = {
-        val properties: Seq[Property[T]] = constraintSet.asMap.keys.flatMap {
-          case _: Def[T] => Nil
-          case other => Seq(other.any())
-        }.toSeq
-        PropertySet(properties)
-      }
+      override def get(node: T): PropertySet[T] =
+        compose(roleDef.getProperty(node), delegate.get(node))
 
-      val constraintSets = constraintSet.get(roleDef).asReq() match {
-        case _: ReqAny[T] =>
-          delegateChildrenConstraintSets().map(
-            delegateConstraint => compose(roleDef.any(), delegateConstraint))
-        case _: ReqHub[T] =>
-          Seq.tabulate(numChildren)(_ => compose(roleDef.reqHub, 
delegateConstraintSetAny()))
-        case _: ReqUser[T] =>
-          delegateChildrenConstraintSets().map(
-            delegateConstraint => compose(roleDef.reqUser, delegateConstraint))
-      }
+      override def childrenConstraintSets(
+          node: T,
+          constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
+        assert(!roleDef.planModel.isGroupLeaf(node))
 
-      constraintSets
-    }
+        if (roleDef.planModel.isLeaf(node)) {
+          return Nil
+        }
 
-    override def assignToGroup(group: GroupLeafBuilder[T], constraintSet: 
PropertySet[T]): Unit = {
-      roleDef.assignToGroup(group, constraintSet.asMap(roleDef))
-      delegate.assignToGroup(group, PropertySet(constraintSet.asMap - roleDef))
-    }
+        val numChildren = roleDef.planModel.childrenOf(node).size
 
-    override def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
-      new EnforcerRuleFactory[T] {
-        private val delegateFactory: EnforcerRuleFactory[T] = 
delegate.newEnforcerRuleFactory()
+        def delegateChildrenConstraintSets(): Seq[PropertySet[T]] = {
+          val roleRemoved = PropertySet(constraintSet.asMap - roleDef)
+          val out = delegate.childrenConstraintSets(node, roleRemoved)
+          out
+        }
+
+        def delegateConstraintSetAny(): PropertySet[T] = {
+          val properties: Seq[Property[T]] = constraintSet.asMap.keys.flatMap {
+            case _: Def[T] => Nil
+            case other => Seq(other.any())
+          }.toSeq
+          PropertySet(properties)
+        }
 
-        override def newEnforcerRules(constraintSet: PropertySet[T]): 
Seq[RasRule[T]] = {
-          delegateFactory.newEnforcerRules(constraintSet -: roleDef)
+        val constraintSets = constraintSet.get(roleDef).asReq() match {
+          case _: ReqAny[T] =>
+            delegateChildrenConstraintSets().map(
+              delegateConstraint => compose(roleDef.any(), delegateConstraint))
+          case _: ReqHub[T] =>
+            Seq.tabulate(numChildren)(_ => compose(roleDef.reqHub, 
delegateConstraintSetAny()))
+          case _: ReqUser[T] =>
+            delegateChildrenConstraintSets().map(
+              delegateConstraint => compose(roleDef.reqUser, 
delegateConstraint))
+        }
+
+        constraintSets
+      }
+
+      override def assignToGroup(
+          group: GroupLeafBuilder[T],
+          constraintSet: PropertySet[T]): Unit = {
+        roleDef.assignToGroup(group, constraintSet.asMap(roleDef))
+        delegate.assignToGroup(group, PropertySet(constraintSet.asMap - 
roleDef))
+      }
+
+      override def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
+        new EnforcerRuleFactory[T] {
+          private val delegateFactory: EnforcerRuleFactory[T] = 
delegate.newEnforcerRuleFactory()
+
+          override def newEnforcerRules(constraintSet: PropertySet[T]): 
Seq[RasRule[T]] = {
+            delegateFactory.newEnforcerRules(constraintSet -: roleDef)
+          }
         }
       }
-    }
 
-    private def compose(memoRole: MemoRole[T], base: PropertySet[T]): 
PropertySet[T] = {
-      base +: memoRole
+      private def compose(memoRole: MemoRole[T], base: PropertySet[T]): 
PropertySet[T] = {
+        base +: memoRole
+      }
     }
   }
 }
diff --git 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
index a254851f43..c9604aa013 100644
--- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
+++ 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
@@ -59,7 +59,7 @@ object EnforcerRuleSet {
     }
 
     private class Regular[T <: AnyRef](ras: Ras[T], closure: Closure[T]) extends Factory[T] {
-      private val factory = ras.propertySetFactory().newEnforcerRuleFactory()
+      private val factory = ras.newEnforcerRuleFactory()
       private val ruleSetBuffer = mutable.Map[PropertySet[T], EnforcerRuleSet[T]]()
 
       override def ruleSetOf(constraintSet: PropertySet[T]): EnforcerRuleSet[T] = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 4b4e0d45da..1b549d6663 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.ras.memo.MemoState
 import org.apache.gluten.ras.path._
 
 import scala.collection.mutable
+import scala.util.Random
 
 // Visualize the planning procedure using dot language.
 class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best: Best[T]) {
@@ -41,7 +42,8 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
 
     val buf = new StringBuilder()
     buf.append("digraph G {\n")
-    buf.append("  compound=true;\n")
+    buf.append("  compound=true\n")
+    buf.append("  rankdir=TB\n")
 
     object IsBestNode {
       def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])): Boolean = {
@@ -57,6 +59,18 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
 
     val clusterToGroups: mutable.Map[RasClusterKey, mutable.Set[Int]] = mutable.Map()
 
+    def determineGroupColor(group: RasGroup[T]): String = {
+      val isRootGroup = group.id() == rootGroupId
+      if (isRootGroup) {
+        return "lightyellow"
+      }
+      val isHubGroup = group.constraintSet() == ras.hubConstraintSet()
+      if (isHubGroup) {
+        return "lightgrey"
+      }
+      "lightblue"
+    }
+
     allGroups.foreach {
       group => clusterToGroups.getOrElseUpdate(group.clusterKey(), mutable.Set()).add(group.id())
     }
@@ -71,6 +85,8 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
         clusterToGroups(clusterKey).map(allGroups(_)).foreach {
           group =>
             buf.append(s"    subgraph cluster$dotClusterId {\n")
+            buf.append(s"      style=filled\n")
+            buf.append(s"      fillcolor=${determineGroupColor(group)}\n")
             groupToDotClusterId += group.id() -> dotClusterId
             dotClusterId = dotClusterId + 1
             buf.append(s"      label=${'"'}${describeGroupVerbose(group)}${'"'}\n")
@@ -80,9 +96,9 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
                   buf.append(s"      ${'"'}${describeNode(costs, group, node)}${'"'}")
                   (node, group) match {
                     case IsBestNode() =>
-                      buf.append(" [style=filled, fillcolor=green] ")
+                      buf.append(" [style=filled, fillcolor=lightgreen] ")
                     case IsWinnerNode() =>
-                      buf.append(" [style=filled, fillcolor=grey] ")
+                      buf.append(" [style=filled, fillcolor=lightgrey] ")
                     case _ =>
                   }
                   buf.append("\n")
@@ -99,9 +115,9 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
           node =>
             node.getChildrenGroups(allGroups).map(_.group(allGroups)).foreach {
               childGroup =>
-                val childGroupNodes = childGroup.nodes(memoState)
+                val childGroupNodes = childGroup.nodes(memoState).toSeq
                 if (childGroupNodes.nonEmpty) {
-                  val randomChild = childGroupNodes.head
+                  val randomChild = childGroupNodes(Random.nextInt(childGroupNodes.size))
                   buf.append(
                     s"  ${'"'}${describeNode(costs, group, node)}${'"'} -> " +
                       s"${'"'}${describeNode(costs, childGroup, randomChild)}${'"'}  " +
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index b33073cd3c..487a5986d1 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -53,11 +53,6 @@ case class MockMemoState[T <: AnyRef] private (
 
   override def getHubGroup(key: RasClusterKey): RasGroup[T] =
     throw new UnsupportedOperationException()
-
-  override def clusterUserGroupLookup(): Map[RasClusterKey, RasGroup[T]] = Map.empty
-
-  override def getUserGroup(key: RasClusterKey): RasGroup[T] =
-    throw new UnsupportedOperationException()
 }
 
 object MockMemoState {
@@ -153,7 +148,7 @@ object MockMemoState {
             id,
             clusterKey,
             propSet,
-            ras.newGroupLeaf(id, clusterKey.metadata, propSet +: ras.memoRoleDef.reqUser))
+            ras.newGroupLeaf(id, clusterKey.metadata, ras.withUserConstraint(propSet)))
         groupBuffer += group
         group
       }
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/property/MemoRoleSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/property/MemoRoleSuite.scala
new file mode 100644
index 0000000000..9cb43cd6ad
--- /dev/null
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/property/MemoRoleSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.ras.property
+
+import org.apache.gluten.ras.{Ras, RasConfig, RasSuite}
+import org.apache.gluten.ras.RasSuiteBase.{CostModelImpl, ExplainImpl, MetadataModelImpl, PlanModelImpl, PropertyModelImpl, TestNode}
+import org.apache.gluten.ras.rule.RasRule
+
+class MemoRoleSuite extends RasSuite {
+  override protected def conf: RasConfig = RasConfig(plannerType = RasConfig.PlannerType.Dp)
+
+  test("equality") {
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.none())
+        .withNewConfig(_ => conf)
+    val one = ras.userConstraintSet()
+    val other = ras.withUserConstraint(PropertySet(Nil))
+    assert(one == other)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to