This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new c50e21ddd0 [VL] RAS: Various fixes (#9803)
c50e21ddd0 is described below
commit c50e21ddd032e494896a49f0832f3ac7c4c471b5
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu May 29 18:34:30 2025 +0100
[VL] RAS: Various fixes (#9803)
---
.../enumerated/planner/VeloxRasSuite.scala | 4 +-
.../src/main/scala/org/apache/gluten/ras/Ras.scala | 46 +++---
.../scala/org/apache/gluten/ras/RasCluster.scala | 12 +-
.../org/apache/gluten/ras/dp/DpGroupAlgo.scala | 4 +-
.../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 22 +--
.../gluten/ras/exaustive/ExhaustivePlanner.scala | 21 +--
.../apache/gluten/ras/memo/ForwardMemoTable.scala | 40 +++--
.../scala/org/apache/gluten/ras/memo/Memo.scala | 12 +-
.../org/apache/gluten/ras/memo/MemoTable.scala | 17 +--
.../org/apache/gluten/ras/property/MemoRole.scala | 162 ++++++++++++---------
.../apache/gluten/ras/rule/EnforcerRuleSet.scala | 2 +-
.../apache/gluten/ras/vis/GraphvizVisualizer.scala | 26 +++-
.../org/apache/gluten/ras/mock/MockMemoState.scala | 7 +-
.../apache/gluten/ras/property/MemoRoleSuite.scala | 40 +++++
14 files changed, 230 insertions(+), 185 deletions(-)
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
index 987e8ce70b..8db5c0b1b9 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/enumerated/planner/VeloxRasSuite.scala
@@ -81,7 +81,7 @@ class VeloxRasSuite extends SharedSparkSession {
val numGroups = memoState.allGroups().size
val numNodes = memoState.allClusters().flatMap(_.nodes()).size
assert(numClusters == 8)
- assert(numGroups == 30)
+ assert(numGroups == 22)
assert(numNodes == 39)
}
@@ -110,7 +110,7 @@ class VeloxRasSuite extends SharedSparkSession {
val numGroups = memoState.allGroups().size
val numNodes = memoState.allClusters().flatMap(_.nodes()).size
assert(numClusters == 8)
- assert(numGroups == 32)
+ assert(numGroups == 28)
assert(numNodes == 55)
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 785afe5ebc..7dc2c2d42c 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -17,11 +17,12 @@
package org.apache.gluten.ras
import org.apache.gluten.ras.property.{MemoRole, PropertySet,
PropertySetFactory}
-import org.apache.gluten.ras.rule.RasRule
+import org.apache.gluten.ras.property.MemoRole.PropertySetFactoryWithMemoRole
+import org.apache.gluten.ras.rule.{EnforcerRuleFactory, RasRule}
/**
- * Entrypoint of RAS (relational algebra selector)'s search engine. See basic
introduction of RAS:
- * https://github.com/apache/incubator-gluten/issues/5057.
+ * Entrypoint of the RAS (relational algebra selector) search engine. See the
basic introduction of
+ * RAS: https://github.com/apache/incubator-gluten/issues/5057.
*/
trait Optimization[T <: AnyRef] {
def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T]
@@ -50,11 +51,11 @@ class Ras[T <: AnyRef] private (
extends Optimization[T] {
import Ras._
- private[ras] val memoRoleDef: MemoRole.Def[T] = MemoRole.newDef(planModel)
- private val userPropertySetFactory: PropertySetFactory[T] =
- PropertySetFactory(propertyModel, planModel)
- private val propSetFactory: PropertySetFactory[T] =
- MemoRole.wrapPropertySetFactory(userPropertySetFactory, memoRoleDef)
+ private val propSetFactory: PropertySetFactoryWithMemoRole[T] = {
+ val memoRoleDef: MemoRole.Def[T] = MemoRole.newDef(planModel)
+ val baseFactory = PropertySetFactory(propertyModel, planModel)
+ MemoRole.wrapPropertySetFactory(baseFactory, memoRoleDef)
+ }
// Normal groups start with ID 0, so it's safe to use Int.MinValue to do
validation.
private val dummyGroup: T =
newGroupLeaf(Int.MinValue, metadataModel.dummy(), propSetFactory.any())
@@ -91,11 +92,11 @@ class Ras[T <: AnyRef] private (
}
override def newPlanner(plan: T, constraintSet: PropertySet[T]):
RasPlanner[T] = {
- RasPlanner(this, constraintSet, plan)
+ RasPlanner(this, withUserConstraint(constraintSet), plan)
}
def newPlanner(plan: T): RasPlanner[T] = {
- RasPlanner(this, userPropertySetFactory.any(), plan)
+ RasPlanner(this, userConstraintSet(), plan)
}
def withNewConfig(confFunc: RasConfig => RasConfig): Ras[T] = {
@@ -109,15 +110,26 @@ class Ras[T <: AnyRef] private (
ruleFactory)
}
- private[ras] def userConstraintSet(): PropertySet[T] =
- userPropertySetFactory.any() +: memoRoleDef.reqUser
+ private[ras] def withUserConstraint(from: PropertySet[T]): PropertySet[T] = {
+ from +: propSetFactory.userConstraint()
+ }
+
+ private[ras] def userConstraintSet(): PropertySet[T] =
propSetFactory.userConstraintSet()
- private[ras] def hubConstraintSet(): PropertySet[T] =
- userPropertySetFactory.any() +: memoRoleDef.reqHub
+ private[ras] def hubConstraintSet(): PropertySet[T] =
propSetFactory.hubConstraintSet()
private[ras] def propSetOf(plan: T): PropertySet[T] = {
- val out = propertySetFactory().get(plan)
- out
+ propSetFactory.get(plan)
+ }
+
+ private[ras] def childrenConstraintSets(
+ node: T,
+ constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
+ propSetFactory.childrenConstraintSets(node, constraintSet)
+ }
+
+ private[ras] def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
+ propSetFactory.newEnforcerRuleFactory()
}
private[ras] def withNewChildren(node: T, newChildren: Seq[T]): T = {
@@ -148,8 +160,6 @@ class Ras[T <: AnyRef] private (
.map(child => planModel.getGroupId(child))
}
- private[ras] def propertySetFactory(): PropertySetFactory[T] = propSetFactory
-
private[ras] def dummyGroupLeaf(): T = {
dummyGroup
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index 98f03eb961..c186752cda 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -77,14 +77,12 @@ object RasCluster {
}
}
- case class ImmutableRasCluster[T <: AnyRef] private (
- ras: Ras[T],
- override val nodes: Seq[CanonicalNode[T]])
- extends RasCluster[T]
-
object ImmutableRasCluster {
- def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]):
ImmutableRasCluster[T] = {
- ImmutableRasCluster(ras, cluster.nodes().toVector)
+ def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]): RasCluster[T]
= {
+ new Impl[T](ras, cluster.nodes().toSeq)
}
+
+ private class Impl[T <: AnyRef](ras: Ras[T], override val nodes:
Seq[CanonicalNode[T]])
+ extends RasCluster[T]
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
index 13e103cfce..9172814354 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
@@ -20,8 +20,8 @@ import org.apache.gluten.ras.{InGroupNode, RasGroup}
import org.apache.gluten.ras.dp.DpZipperAlgo.Solution
import org.apache.gluten.ras.memo.MemoState
-// Dynamic programming algorithm to solve problem against a single RAS group
that can be
-// broken down to sub problems for subgroups.
+// Dynamic programming algorithm to solve a problem for a single RAS group
that can be
+// broken down into subproblems for subgroups.
trait DpGroupAlgoDef[T <: AnyRef, NodeOutput <: AnyRef, GroupOutput <: AnyRef]
{
def solveNode(node: InGroupNode[T], childrenGroupsOutput: RasGroup[T] =>
GroupOutput): NodeOutput
def solveGroup(group: RasGroup[T], nodesOutput: InGroupNode[T] =>
NodeOutput): GroupOutput
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 0aca6bf13a..b2d429ed0e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -36,7 +36,7 @@ private class DpPlanner[T <: AnyRef] private (ras: Ras[T],
constraintSet: Proper
private val deriverRuleSetFactory = EnforcerRuleSet.Factory.derive(ras, memo)
private lazy val rootGroupId: Int = {
- memo.memorize(plan, constraintSet +: ras.memoRoleDef.reqUser).id()
+ memo.memorize(plan, constraintSet).id()
}
private lazy val best: (Best[T], KnownCostPath[T]) = {
@@ -100,7 +100,6 @@ object DpPlanner {
override def exploreChildX(
panel: Panel[InClusterNode[T], RasClusterKey],
x: InClusterNode[T]): Unit = {
- applyHubRulesOnUserNode(panel, x.clusterKey, x.can)
applyRulesOnHubNode(panel, x.clusterKey, x.can)
}
@@ -136,25 +135,6 @@ object DpPlanner {
}
}
- private def applyHubRulesOnUserNode(
- panel: Panel[InClusterNode[T], RasClusterKey],
- cKey: RasClusterKey,
- can: CanonicalNode[T]): Unit = {
- val hubConstraint = ras.hubConstraintSet()
- val hubDeriverRuleSet = deriverRuleSetFactory.ruleSetOf(hubConstraint)
- val hubDeriverRules = hubDeriverRuleSet.rules()
- if (hubDeriverRules.nonEmpty) {
- val hubDeriverRuleShapes = hubDeriverRuleSet.shapes()
- val userGroup = memoTable.getUserGroup(cKey)
- findPaths(
- GroupNode(ras, userGroup),
- hubDeriverRuleShapes,
- List(new FromSingleNode[T](can))) {
- path => hubDeriverRules.foreach(rule => applyRule(panel, cKey, rule,
path))
- }
- }
- }
-
private def applyEnforcerRules(
panel: Panel[InClusterNode[T], RasClusterKey],
cKey: RasClusterKey): Unit = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index 58a37afa47..dca811134b 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -36,7 +36,7 @@ private class ExhaustivePlanner[T <: AnyRef] private (
private val deriverRuleSetFactory = EnforcerRuleSet.Factory.derive(ras, memo)
private lazy val rootGroupId: Int = {
- memo.memorize(plan, constraintSet +: ras.memoRoleDef.reqUser).id()
+ memo.memorize(plan, constraintSet).id()
}
private lazy val best: (Best[T], KnownCostPath[T]) = {
@@ -91,7 +91,6 @@ object ExhaustivePlanner {
def explore(): Unit = {
// TODO: ONLY APPLY RULES ON ALTERED GROUPS (and close parents)
- applyHubRules()
applyEnforcerRules()
applyRules()
}
@@ -129,24 +128,6 @@ object ExhaustivePlanner {
}
}
- private def applyHubRules(): Unit = {
- val hubConstraint = ras.hubConstraintSet()
- val hubDeriverRuleSet = deriverRuleSetFactory.ruleSetOf(hubConstraint)
- val hubDeriverRules = hubDeriverRuleSet.rules()
- if (hubDeriverRules.nonEmpty) {
- memoState
- .clusterLookup()
- .foreach {
- case (cKey, cluster) =>
- val hubDeriverRuleShapes = hubDeriverRuleSet.shapes()
- val userGroup = memoState.getUserGroup(cKey)
- findPaths(GroupNode(ras, userGroup), hubDeriverRuleShapes) {
- path => hubDeriverRules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))
- }
- }
- }
- }
-
private def applyEnforcerRules(): Unit = {
allGroups.foreach {
group =>
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index cbd43026c6..791e848c2f 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -48,13 +48,12 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
override def newCluster(metadata: Metadata): RasClusterKey = {
checkBufferSizes()
val clusterId = clusterBuffer.size
- val key = IntClusterKey(clusterId, metadata)
+ val key = IntClusterKey(clusterId)(metadata)
clusterKeyBuffer += key
clusterBuffer += MutableRasCluster(ras, metadata)
clusterDisjointSet.grow()
groupLookup += mutable.Map()
groupOf(key, ras.hubConstraintSet())
- groupOf(key, ras.userConstraintSet())
memoWriteCount += 1
key
}
@@ -67,7 +66,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
}
val gid = groupBuffer.size
val newGroup =
- RasGroup(ras, IntClusterKey(ancestor, key.metadata), gid, constraintSet)
+ RasGroup(ras, IntClusterKey(ancestor)(key.metadata), gid, constraintSet)
lookup += constraintSet -> newGroup
groupBuffer += newGroup
memoWriteCount += 1
@@ -80,12 +79,24 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
}
override def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
= {
+ if (addToCluster0(key, node)) {
+ // Insert the corresponding hub node right away.
+ addToCluster0(key, node.toHubNode(this))
+ return
+ }
+ // Node was already inserted into the cluster.
+ // Assert that the corresponding hub node was inserted as
well.
+ assert(!addToCluster0(key, node.toHubNode(this)))
+ }
+
+ private def addToCluster0(key: RasClusterKey, node: CanonicalNode[T]):
Boolean = {
val cluster = getCluster(key)
if (cluster.contains(node)) {
- return
+ return false
}
cluster.add(node)
memoWriteCount += 1
+ true
}
override def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit =
{
@@ -169,18 +180,12 @@ class ForwardMemoTable[T <: AnyRef] private (override val
ras: Ras[T])
val lookup = groupLookup(ancestor)
lookup(ras.hubConstraintSet())
}
-
- override def getUserGroup(key: RasClusterKey): RasGroup[T] = {
- val ancestor = ancestorClusterIdOf(key)
- val lookup = groupLookup(ancestor)
- lookup(ras.userConstraintSet())
- }
}
object ForwardMemoTable {
def apply[T <: AnyRef](ras: Ras[T]): MemoTable.Writable[T] = new
ForwardMemoTable[T](ras)
- private case class IntClusterKey(id: Int, metadata: Metadata) extends
RasClusterKey
+ private case class IntClusterKey(id: Int)(override val metadata: Metadata)
extends RasClusterKey
private class Probe[T <: AnyRef](table: ForwardMemoTable[T]) extends
MemoTable.Probe[T] {
private val probedClusterCount: Int = table.clusterKeyBuffer.size
@@ -228,4 +233,17 @@ object ForwardMemoTable {
key.asInstanceOf[IntClusterKey]
}
}
+
+ implicit class CanonicalNodeImplicits[T <: AnyRef](node: CanonicalNode[T]) {
+ def toHubNode(store: MemoStore[T]): CanonicalNode[T] = {
+ val ras = node.ras()
+ val canUnsafe = ras.withNewChildren(
+ node.self(),
+ ras
+ .getChildrenGroupIds(node.self())
+ .map(gid => store.asGroupSupplier()(gid).clusterKey())
+ .map(cKey => store.getHubGroup(cKey).self()))
+ CanonicalNode(ras, canUnsafe)
+ }
+ }
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index db3e90d7ab..1286824723 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -66,7 +66,7 @@ object Memo {
if (cache.contains(cacheKey)) {
cache(cacheKey)
} else {
- // Node not yet added to cluster.
+ // Node was not yet added to a cluster.
val cluster = newCluster(metadata)
cache += (cacheKey -> cluster)
cluster
@@ -135,7 +135,7 @@ object Memo {
if (residentCluster == targetCluster) {
return Prepare.cluster(parent, targetCluster)
}
- // The resident cluster of group leaf is not the same with target
cluster.
+ // The resident cluster of the group leaf is different from the target
cluster.
// Merge.
parent.memoTable.mergeClusters(residentCluster, targetCluster)
return Prepare.cluster(parent, targetCluster)
@@ -153,7 +153,7 @@ object Memo {
val cacheKey = parent.toCacheKey(keyUnsafe)
if (!parent.cache.contains(cacheKey)) {
- // The new node was not added to memo yet. Add it to the target
cluster.
+ // The new node was not added to the memo yet. Add it to the target
cluster.
parent.cache += (cacheKey -> targetCluster)
return Prepare.tree(parent, targetCluster, childrenPrepares)
}
@@ -164,7 +164,7 @@ object Memo {
// The new node already memorized to memo and in the target cluster.
return Prepare.tree(parent, targetCluster, childrenPrepares)
}
- // The new node already memorized to memo, but in the different
cluster.
+ // The new node was already memorized in the memo, but in a different cluster.
// Merge the two clusters.
parent.memoTable.mergeClusters(cachedCluster, targetCluster)
Prepare.tree(parent, targetCluster, childrenPrepares)
@@ -204,7 +204,7 @@ object Memo {
assert(!ras.isGroupLeaf(node))
val childrenGroups = children
.zip(ras.planModel.childrenOf(node))
- .zip(ras.propertySetFactory().childrenConstraintSets(node,
constraintSet))
+ .zip(ras.childrenConstraintSets(node, constraintSet))
.map {
case ((childPrepare, child), childConstraintSet) =>
childPrepare.doInsert(child, childConstraintSet)
@@ -250,7 +250,6 @@ object Memo {
trait MemoStore[T <: AnyRef] {
def getCluster(key: RasClusterKey): RasCluster[T]
def getHubGroup(key: RasClusterKey): RasGroup[T]
- def getUserGroup(key: RasClusterKey): RasGroup[T]
def getGroup(id: Int): RasGroup[T]
}
@@ -266,7 +265,6 @@ trait MemoState[T <: AnyRef] extends MemoStore[T] {
def ras(): Ras[T]
def clusterLookup(): Map[RasClusterKey, RasCluster[T]]
def clusterHubGroupLookup(): Map[RasClusterKey, RasGroup[T]]
- def clusterUserGroupLookup(): Map[RasClusterKey, RasGroup[T]]
def allClusters(): Iterable[RasCluster[T]]
def allGroups(): Seq[RasGroup[T]]
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 3bdf7b794e..5b1b8d04d6 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -66,9 +66,8 @@ object MemoTable {
private case class MemoStateImpl[T <: AnyRef](
override val ras: Ras[T],
- override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
+ override val clusterLookup: Map[RasClusterKey, RasCluster[T]],
override val clusterHubGroupLookup: Map[RasClusterKey, RasGroup[T]],
- override val clusterUserGroupLookup: Map[RasClusterKey, RasGroup[T]],
override val allGroups: Seq[RasGroup[T]],
idToGroup: Map[Int, RasGroup[T]])
extends MemoState[T] {
@@ -76,7 +75,6 @@ object MemoTable {
override def getCluster(key: RasClusterKey): RasCluster[T] =
clusterLookup(key)
override def getHubGroup(key: RasClusterKey): RasGroup[T] =
clusterHubGroupLookup(key)
- override def getUserGroup(key: RasClusterKey): RasGroup[T] =
clusterUserGroupLookup(key)
override def getGroup(id: Int): RasGroup[T] = idToGroup(id)
override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
}
@@ -93,11 +91,6 @@ object MemoTable {
.map(key => key -> table.getHubGroup(key))
.toMap
- val immutableUserGroups = table
- .allClusterKeys()
- .map(key => key -> table.getUserGroup(key))
- .toMap
-
var maxGroupId = Int.MinValue
val groupMap = table
@@ -115,13 +108,7 @@ object MemoTable {
val allGroups = (0 to maxGroupId).map(table.getGroup).toVector
- MemoStateImpl(
- table.ras,
- immutableClusters,
- immutableHubGroups,
- immutableUserGroups,
- allGroups,
- groupMap)
+ MemoStateImpl(table.ras, immutableClusters, immutableHubGroups,
allGroups, groupMap)
}
def doExhaustively(func: => Unit): Unit = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
index 523597eddb..62e5936d12 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/property/MemoRole.scala
@@ -27,8 +27,8 @@ sealed trait MemoRole[T <: AnyRef] extends Property[T] {
object MemoRole {
implicit class MemoRoleImplicits[T <: AnyRef](role: MemoRole[T]) {
- def asReq(): Req[T] = role.asInstanceOf[Req[T]]
- def asProp(): Prop[T] = role.asInstanceOf[Prop[T]]
+ private[MemoRole] def asReq(): Req[T] = role.asInstanceOf[Req[T]]
+ private[MemoRole] def asProp(): Prop[T] = role.asInstanceOf[Prop[T]]
def +:(base: PropertySet[T]): PropertySet[T] = {
require(!base.asMap.contains(role.definition()))
@@ -39,14 +39,14 @@ object MemoRole {
}
}
- trait Req[T <: AnyRef] extends MemoRole[T]
- trait Prop[T <: AnyRef] extends MemoRole[T]
+ sealed private trait Req[T <: AnyRef] extends MemoRole[T]
+ sealed private trait Prop[T <: AnyRef] extends MemoRole[T]
// Constraints.
- class ReqHub[T <: AnyRef] private[MemoRole] (
+ private class ReqHub[T <: AnyRef] private[MemoRole] (
override val definition: PropertyDef[T, _ <: Property[T]])
extends Req[T]
- class ReqUser[T <: AnyRef] private[MemoRole] (
+ private class ReqUser[T <: AnyRef] private[MemoRole] (
override val definition: PropertyDef[T, _ <: Property[T]])
extends Req[T]
private class ReqAny[T <: AnyRef] private[MemoRole] (
@@ -54,13 +54,13 @@ object MemoRole {
extends Req[T]
// Props.
- class Leaf[T <: AnyRef] private[MemoRole] (
+ private class Leaf[T <: AnyRef] private[MemoRole] (
override val definition: PropertyDef[T, _ <: Property[T]])
extends Prop[T]
- class Hub[T <: AnyRef] private[MemoRole] (
+ private class Hub[T <: AnyRef] private[MemoRole] (
override val definition: PropertyDef[T, _ <: Property[T]])
extends Prop[T]
- class User[T <: AnyRef] private[MemoRole] (
+ private class User[T <: AnyRef] private[MemoRole] (
override val definition: PropertyDef[T, _ <: Property[T]])
extends Prop[T]
@@ -68,13 +68,13 @@ object MemoRole {
extends PropertyDef[T, MemoRole[T]] {
private val groupRoleLookup = mutable.Map[Int, Prop[T]]()
+ private[MemoRole] val reqHub = new ReqHub[T](this)
+ private[MemoRole] val reqUser = new ReqUser[T](this)
private val reqAny = new ReqAny[T](this)
- val reqHub = new ReqHub[T](this)
- val reqUser = new ReqUser[T](this)
- val leaf = new Leaf[T](this)
- val hub = new Hub[T](this)
- val user = new User[T](this)
+ private val leaf = new Leaf[T](this)
+ private val hub = new Hub[T](this)
+ private val user = new User[T](this)
override def any(): MemoRole[T] = reqAny
@@ -127,7 +127,7 @@ object MemoRole {
}
}
- implicit class DefImplicits[T <: AnyRef](roleDef: Def[T]) {
+ implicit private class DefImplicits[T <: AnyRef](roleDef: Def[T]) {
def -:(base: PropertySet[T]): PropertySet[T] = {
require(base.asMap.contains(roleDef))
val map: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = {
@@ -137,82 +137,104 @@ object MemoRole {
}
}
- def newDef[T <: AnyRef](planModel: PlanModel[T]): Def[T] = {
+ private[ras] def newDef[T <: AnyRef](planModel: PlanModel[T]): Def[T] = {
new Def[T](planModel)
}
- def wrapPropertySetFactory[T <: AnyRef](
+ private[ras] def wrapPropertySetFactory[T <: AnyRef](
factory: PropertySetFactory[T],
- roleDef: Def[T]): PropertySetFactory[T] = {
- new PropertySetFactoryWithMemoRole[T](factory, roleDef)
+ roleDef: Def[T]): PropertySetFactoryWithMemoRole[T] = {
+ PropertySetFactoryWithMemoRole(factory, roleDef)
}
- private class PropertySetFactoryWithMemoRole[T <: AnyRef](
- delegate: PropertySetFactory[T],
- roleDef: Def[T])
- extends PropertySetFactory[T] {
+ trait PropertySetFactoryWithMemoRole[T <: AnyRef] extends
PropertySetFactory[T] {
+ def userConstraint(): MemoRole[T]
+ def userConstraintSet(): PropertySet[T]
+ def hubConstraintSet(): PropertySet[T]
+ }
- override val any: PropertySet[T] = compose(roleDef.any(), delegate.any())
+ private object PropertySetFactoryWithMemoRole {
+ def apply[T <: AnyRef](
+ factory: PropertySetFactory[T],
+ roleDef: Def[T]): PropertySetFactoryWithMemoRole[T] = {
+ new Impl(factory, roleDef)
+ }
- override def get(node: T): PropertySet[T] =
- compose(roleDef.getProperty(node), delegate.get(node))
+ private class Impl[T <: AnyRef](delegate: PropertySetFactory[T], roleDef:
Def[T])
+ extends PropertySetFactoryWithMemoRole[T] {
- override def childrenConstraintSets(
- node: T,
- constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
- assert(!roleDef.planModel.isGroupLeaf(node))
+ override val any: PropertySet[T] = compose(roleDef.any(), delegate.any())
- if (roleDef.planModel.isLeaf(node)) {
- return Nil
- }
+ override val userConstraint: MemoRole[T] = roleDef.reqUser
- val numChildren = roleDef.planModel.childrenOf(node).size
+ override val userConstraintSet: PropertySet[T] =
+ delegate.any() +: roleDef.reqUser
- def delegateChildrenConstraintSets(): Seq[PropertySet[T]] = {
- val roleRemoved = PropertySet(constraintSet.asMap - roleDef)
- val out = delegate.childrenConstraintSets(node, roleRemoved)
- out
- }
+ override val hubConstraintSet: PropertySet[T] =
+ delegate.any() +: roleDef.reqHub
- def delegateConstraintSetAny(): PropertySet[T] = {
- val properties: Seq[Property[T]] = constraintSet.asMap.keys.flatMap {
- case _: Def[T] => Nil
- case other => Seq(other.any())
- }.toSeq
- PropertySet(properties)
- }
+ override def get(node: T): PropertySet[T] =
+ compose(roleDef.getProperty(node), delegate.get(node))
- val constraintSets = constraintSet.get(roleDef).asReq() match {
- case _: ReqAny[T] =>
- delegateChildrenConstraintSets().map(
- delegateConstraint => compose(roleDef.any(), delegateConstraint))
- case _: ReqHub[T] =>
- Seq.tabulate(numChildren)(_ => compose(roleDef.reqHub,
delegateConstraintSetAny()))
- case _: ReqUser[T] =>
- delegateChildrenConstraintSets().map(
- delegateConstraint => compose(roleDef.reqUser, delegateConstraint))
- }
+ override def childrenConstraintSets(
+ node: T,
+ constraintSet: PropertySet[T]): Seq[PropertySet[T]] = {
+ assert(!roleDef.planModel.isGroupLeaf(node))
- constraintSets
- }
+ if (roleDef.planModel.isLeaf(node)) {
+ return Nil
+ }
- override def assignToGroup(group: GroupLeafBuilder[T], constraintSet:
PropertySet[T]): Unit = {
- roleDef.assignToGroup(group, constraintSet.asMap(roleDef))
- delegate.assignToGroup(group, PropertySet(constraintSet.asMap - roleDef))
- }
+ val numChildren = roleDef.planModel.childrenOf(node).size
- override def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
- new EnforcerRuleFactory[T] {
- private val delegateFactory: EnforcerRuleFactory[T] =
delegate.newEnforcerRuleFactory()
+ def delegateChildrenConstraintSets(): Seq[PropertySet[T]] = {
+ val roleRemoved = PropertySet(constraintSet.asMap - roleDef)
+ val out = delegate.childrenConstraintSets(node, roleRemoved)
+ out
+ }
+
+ def delegateConstraintSetAny(): PropertySet[T] = {
+ val properties: Seq[Property[T]] = constraintSet.asMap.keys.flatMap {
+ case _: Def[T] => Nil
+ case other => Seq(other.any())
+ }.toSeq
+ PropertySet(properties)
+ }
- override def newEnforcerRules(constraintSet: PropertySet[T]):
Seq[RasRule[T]] = {
- delegateFactory.newEnforcerRules(constraintSet -: roleDef)
+ val constraintSets = constraintSet.get(roleDef).asReq() match {
+ case _: ReqAny[T] =>
+ delegateChildrenConstraintSets().map(
+ delegateConstraint => compose(roleDef.any(), delegateConstraint))
+ case _: ReqHub[T] =>
+ Seq.tabulate(numChildren)(_ => compose(roleDef.reqHub,
delegateConstraintSetAny()))
+ case _: ReqUser[T] =>
+ delegateChildrenConstraintSets().map(
+ delegateConstraint => compose(roleDef.reqUser,
delegateConstraint))
+ }
+
+ constraintSets
+ }
+
+ override def assignToGroup(
+ group: GroupLeafBuilder[T],
+ constraintSet: PropertySet[T]): Unit = {
+ roleDef.assignToGroup(group, constraintSet.asMap(roleDef))
+ delegate.assignToGroup(group, PropertySet(constraintSet.asMap -
roleDef))
+ }
+
+ override def newEnforcerRuleFactory(): EnforcerRuleFactory[T] = {
+ new EnforcerRuleFactory[T] {
+ private val delegateFactory: EnforcerRuleFactory[T] =
delegate.newEnforcerRuleFactory()
+
+ override def newEnforcerRules(constraintSet: PropertySet[T]):
Seq[RasRule[T]] = {
+ delegateFactory.newEnforcerRules(constraintSet -: roleDef)
+ }
}
}
- }
- private def compose(memoRole: MemoRole[T], base: PropertySet[T]):
PropertySet[T] = {
- base +: memoRole
+ private def compose(memoRole: MemoRole[T], base: PropertySet[T]):
PropertySet[T] = {
+ base +: memoRole
+ }
}
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
index a254851f43..c9604aa013 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRuleSet.scala
@@ -59,7 +59,7 @@ object EnforcerRuleSet {
}
private class Regular[T <: AnyRef](ras: Ras[T], closure: Closure[T])
extends Factory[T] {
- private val factory = ras.propertySetFactory().newEnforcerRuleFactory()
+ private val factory = ras.newEnforcerRuleFactory()
private val ruleSetBuffer = mutable.Map[PropertySet[T],
EnforcerRuleSet[T]]()
override def ruleSetOf(constraintSet: PropertySet[T]):
EnforcerRuleSet[T] = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 4b4e0d45da..1b549d6663 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.ras.memo.MemoState
import org.apache.gluten.ras.path._
import scala.collection.mutable
+import scala.util.Random
// Visualize the planning procedure using dot language.
class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T],
best: Best[T]) {
@@ -41,7 +42,8 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState:
MemoState[T], best
val buf = new StringBuilder()
buf.append("digraph G {\n")
- buf.append(" compound=true;\n")
+ buf.append(" compound=true\n")
+ buf.append(" rankdir=TB\n")
object IsBestNode {
def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])):
Boolean = {
@@ -57,6 +59,18 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T],
memoState: MemoState[T], best
val clusterToGroups: mutable.Map[RasClusterKey, mutable.Set[Int]] =
mutable.Map()
+ def determineGroupColor(group: RasGroup[T]): String = {
+ val isRootGroup = group.id() == rootGroupId
+ if (isRootGroup) {
+ return "lightyellow"
+ }
+ val isHubGroup = group.constraintSet() == ras.hubConstraintSet()
+ if (isHubGroup) {
+ return "lightgrey"
+ }
+ "lightblue"
+ }
+
allGroups.foreach {
group => clusterToGroups.getOrElseUpdate(group.clusterKey(),
mutable.Set()).add(group.id())
}
@@ -71,6 +85,8 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState:
MemoState[T], best
clusterToGroups(clusterKey).map(allGroups(_)).foreach {
group =>
buf.append(s" subgraph cluster$dotClusterId {\n")
+ buf.append(s" style=filled\n")
+ buf.append(s" fillcolor=${determineGroupColor(group)}\n")
groupToDotClusterId += group.id() -> dotClusterId
dotClusterId = dotClusterId + 1
buf.append(s"
label=${'"'}${describeGroupVerbose(group)}${'"'}\n")
@@ -80,9 +96,9 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState:
MemoState[T], best
buf.append(s" ${'"'}${describeNode(costs, group,
node)}${'"'}")
(node, group) match {
case IsBestNode() =>
- buf.append(" [style=filled, fillcolor=green] ")
+ buf.append(" [style=filled, fillcolor=lightgreen] ")
case IsWinnerNode() =>
- buf.append(" [style=filled, fillcolor=grey] ")
+ buf.append(" [style=filled, fillcolor=lightgrey] ")
case _ =>
}
buf.append("\n")
@@ -99,9 +115,9 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T],
memoState: MemoState[T], best
node =>
node.getChildrenGroups(allGroups).map(_.group(allGroups)).foreach {
childGroup =>
- val childGroupNodes = childGroup.nodes(memoState)
+ val childGroupNodes = childGroup.nodes(memoState).toSeq
if (childGroupNodes.nonEmpty) {
- val randomChild = childGroupNodes.head
+ val randomChild =
childGroupNodes(Random.nextInt(childGroupNodes.size))
buf.append(
s" ${'"'}${describeNode(costs, group, node)}${'"'} -> " +
s"${'"'}${describeNode(costs, childGroup,
randomChild)}${'"'} " +
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index b33073cd3c..487a5986d1 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -53,11 +53,6 @@ case class MockMemoState[T <: AnyRef] private (
override def getHubGroup(key: RasClusterKey): RasGroup[T] =
throw new UnsupportedOperationException()
-
- override def clusterUserGroupLookup(): Map[RasClusterKey, RasGroup[T]] =
Map.empty
-
- override def getUserGroup(key: RasClusterKey): RasGroup[T] =
- throw new UnsupportedOperationException()
}
object MockMemoState {
@@ -153,7 +148,7 @@ object MockMemoState {
id,
clusterKey,
propSet,
- ras.newGroupLeaf(id, clusterKey.metadata, propSet +:
ras.memoRoleDef.reqUser))
+ ras.newGroupLeaf(id, clusterKey.metadata,
ras.withUserConstraint(propSet)))
groupBuffer += group
group
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/property/MemoRoleSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/property/MemoRoleSuite.scala
new file mode 100644
index 0000000000..9cb43cd6ad
--- /dev/null
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/property/MemoRoleSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.ras.property
+
+import org.apache.gluten.ras.{Ras, RasConfig, RasSuite}
+import org.apache.gluten.ras.RasSuiteBase.{CostModelImpl, ExplainImpl,
MetadataModelImpl, PlanModelImpl, PropertyModelImpl, TestNode}
+import org.apache.gluten.ras.rule.RasRule
+
+class MemoRoleSuite extends RasSuite {
+ override protected def conf: RasConfig = RasConfig(plannerType =
RasConfig.PlannerType.Dp)
+
+ test("equality") {
+ val ras =
+ Ras[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ RasRule.Factory.none())
+ .withNewConfig(_ => conf)
+ val one = ras.userConstraintSet()
+ val other = ras.withUserConstraint(PropertySet(Nil))
+ assert(one == other)
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]