This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new bc99c4910 [CORE][VL] Avoid re-exploring explored nodes in DpPlanner (#5363)
bc99c4910 is described below
commit bc99c4910e26648195fec789c090b06c2b4379e2
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Apr 11 15:53:26 2024 +0800
[CORE][VL] Avoid re-exploring explored nodes in DpPlanner (#5363)
---
.../src/main/scala/org/apache/gluten/ras/Ras.scala | 14 ++---
.../scala/org/apache/gluten/ras/RasCluster.scala | 8 +--
.../main/scala/org/apache/gluten/ras/RasNode.scala | 21 ++++----
.../scala/org/apache/gluten/ras/RasPlanner.scala | 10 ++--
.../org/apache/gluten/ras/best/BestFinder.scala | 6 +--
.../org/apache/gluten/ras/dp/DpClusterAlgo.scala | 2 +-
.../org/apache/gluten/ras/dp/DpGroupAlgo.scala | 2 +-
.../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 62 +++++++++++++++++-----
.../scala/org/apache/gluten/ras/memo/Memo.scala | 6 +--
.../org/apache/gluten/ras/path/OutputFilter.scala | 15 ++++++
.../org/apache/gluten/ras/rule/EnforcerRule.scala | 39 +++++++++-----
.../org/apache/gluten/ras/rule/RuleApplier.scala | 10 ++--
.../apache/gluten/ras/vis/GraphvizVisualizer.scala | 4 +-
.../scala/org/apache/gluten/ras/RasSuite.scala | 53 ++++++++++++++++++
14 files changed, 184 insertions(+), 68 deletions(-)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index f3d46847e..804d04d81 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -171,7 +171,7 @@ class Ras[T <: AnyRef] private (
private[ras] def isInfCost(cost: Cost) =
costModel.costComparator().equiv(cost, infCost)
- private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node)
+ private[ras] def toHashKey(node: T): UnsafeHashKey[T] = UnsafeHashKey(this,
node)
}
object Ras {
@@ -251,15 +251,17 @@ object Ras {
}
}
- trait UnsafeKey[T]
+ trait UnsafeHashKey[T]
- private object UnsafeKey {
- def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new
UnsafeKeyImpl(ras, self)
- private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends
UnsafeKey[T] {
+ private object UnsafeHashKey {
+ def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeHashKey[T] =
+ new UnsafeHashKeyImpl(ras, self)
+ private class UnsafeHashKeyImpl[T <: AnyRef](ras: Ras[T], val self: T)
+ extends UnsafeHashKey[T] {
override def hashCode(): Int = ras.planModel.hashCode(self)
override def equals(other: Any): Boolean = {
other match {
- case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self)
+ case that: UnsafeHashKeyImpl[T] => ras.planModel.equals(self,
that.self)
case _ => false
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
index 1b30e1242..eb2b41a91 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.memo.MemoTable
import org.apache.gluten.ras.property.PropertySet
@@ -55,16 +55,16 @@ object RasCluster {
override val ras: Ras[T],
metadata: Metadata)
extends MutableRasCluster[T] {
- private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set()
+ private val deDup: mutable.Set[UnsafeHashKey[T]] = mutable.Set()
private val buffer: mutable.ListBuffer[CanonicalNode[T]] =
mutable.ListBuffer()
override def contains(t: CanonicalNode[T]): Boolean = {
- deDup.contains(t.toUnsafeKey())
+ deDup.contains(t.toHashKey())
}
override def add(t: CanonicalNode[T]): Unit = {
- val key = t.toUnsafeKey()
+ val key = t.toHashKey()
assert(!deDup.contains(key))
ras.metadataModel.verify(metadata,
ras.metadataModel.metadataOf(t.self()))
deDup += key
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 65ff8b735..710a4e682 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.property.PropertySet
trait RasNode[T <: AnyRef] {
@@ -43,7 +43,7 @@ object RasNode {
node.asInstanceOf[GroupNode[T]]
}
- def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self())
+ def toHashKey(): UnsafeHashKey[T] = node.ras().toHashKey(node.self())
}
}
@@ -131,16 +131,16 @@ object InGroupNode {
private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can:
CanonicalNode[T])
extends InGroupNode[T]
- trait HashKey extends Any
+ trait UniqueKey extends Any
implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) {
import InGroupNodeImplicits._
- def toHashKey: HashKey =
- InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can))
+ def toUniqueKey: UniqueKey =
+ InGroupNodeUniqueKeyImpl(n.groupId, System.identityHashCode(n.can))
}
private object InGroupNodeImplicits {
- private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends
HashKey
+ private case class InGroupNodeUniqueKeyImpl(gid: Int, cid: Int) extends
UniqueKey
}
}
@@ -159,15 +159,16 @@ object InClusterNode {
can: CanonicalNode[T])
extends InClusterNode[T]
- trait HashKey extends Any
+ trait UniqueKey extends Any
implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) {
import InClusterNodeImplicits._
- def toHashKey: HashKey =
- InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can))
+ def toUniqueKey: UniqueKey =
+ InClusterNodeUniqueKeyImpl(n.clusterKey, System.identityHashCode(n.can))
}
private object InClusterNodeImplicits {
- private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey,
cid: Int) extends HashKey
+ private case class InClusterNodeUniqueKeyImpl(clusterKey: RasClusterKey,
cid: Int)
+ extends UniqueKey
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
index 74793a3d0..327b980f3 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala
@@ -62,11 +62,11 @@ object Best {
bestPath: KnownCostPath[T],
winnerNodes: Seq[InGroupNode[T]],
costs: InGroupNode[T] => Option[Cost]): Best[T] = {
- val bestNodes = mutable.Set[InGroupNode.HashKey]()
+ val bestNodes = mutable.Set[InGroupNode.UniqueKey]()
def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = {
val can = cursor.self().asCanonical()
- bestNodes += InGroupNode(groupId, can).toHashKey
+ bestNodes += InGroupNode(groupId, can).toUniqueKey
cursor.zipChildrenWithGroupIds().foreach {
case (childPathNode, childGroupId) =>
dfs(childGroupId, childPathNode)
@@ -76,14 +76,14 @@ object Best {
dfs(rootGroupId, bestPath.rasPath.node())
val bestNodeSet = bestNodes.toSet
- val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet
+ val winnerNodeSet = winnerNodes.map(_.toUniqueKey).toSet
BestImpl(
ras,
rootGroupId,
bestPath,
- n => bestNodeSet.contains(n.toHashKey),
- n => winnerNodeSet.contains(n.toHashKey),
+ n => bestNodeSet.contains(n.toUniqueKey),
+ n => winnerNodeSet.contains(n.toUniqueKey),
costs)
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
index 90a0adfb2..601cd72e5 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
@@ -57,17 +57,17 @@ object BestFinder {
val bestPath = groupToCosts(group.id()).best()
val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id,
g.bestNode) }.toSeq
- val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
+ val costsMap = mutable.Map[InGroupNode.UniqueKey, Cost]()
groupToCosts.foreach {
case (gid, g) =>
g.nodes.foreach {
n =>
val c = g.nodeToCost(n)
if (c.nonEmpty) {
- costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost)
+ costsMap += (InGroupNode(gid, n).toUniqueKey -> c.get.cost)
}
}
}
- Best(ras, group.id(), bestPath, winnerNodes, ign =>
costsMap.get(ign.toHashKey))
+ Best(ras, group.id(), bestPath, winnerNodes, ign =>
costsMap.get(ign.toUniqueKey))
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
index 6fd95772b..046760ceb 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
@@ -73,7 +73,7 @@ object DpClusterAlgo {
clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput])
extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput,
ClusterOutput] {
override def idOfX(x: InClusterNode[T]): Any = {
- x.toHashKey
+ x.toUniqueKey
}
override def idOfY(y: RasClusterKey): Any = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
index c824fda8e..f88f7b6e4 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala
@@ -66,7 +66,7 @@ object DpGroupAlgo {
groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput])
extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput,
GroupOutput] {
override def idOfX(x: InGroupNode[T]): Any = {
- x.toHashKey
+ x.toUniqueKey
}
override def idOfY(y: RasGroup[T]): Any = {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 4a9e3f0f0..3f2590dff 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath
import org.apache.gluten.ras.best.BestFinder
import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel
import org.apache.gluten.ras.memo.{Memo, MemoTable}
-import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath}
+import org.apache.gluten.ras.path._
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape}
@@ -99,10 +99,16 @@ object DpPlanner {
rules: Seq[RuleApplier[T]],
enforcerRuleSet: EnforcerRuleSet[T])
extends DpClusterAlgo.Adjustment[T] {
+ import ExploreAdjustment._
+
+ private val ruleShapes: Seq[Shape[T]] = rules.map(_.shape())
override def exploreChildX(
panel: Panel[InClusterNode[T], RasClusterKey],
- x: InClusterNode[T]): Unit = {}
+ x: InClusterNode[T]): Unit = {
+ applyRulesOnNode(panel, x.clusterKey, x.can)
+ }
+
override def exploreChildY(
panel: Panel[InClusterNode[T], RasClusterKey],
y: RasClusterKey): Unit = {}
@@ -115,20 +121,24 @@ object DpPlanner {
cKey: RasClusterKey): Unit = {
memoTable.doExhaustively {
applyEnforcerRules(panel, cKey)
- applyRules(panel, cKey)
}
}
- private def applyRules(
+ private def applyRulesOnNode(
panel: Panel[InClusterNode[T], RasClusterKey],
- cKey: RasClusterKey): Unit = {
+ cKey: RasClusterKey,
+ can: CanonicalNode[T]): Unit = {
if (rules.isEmpty) {
return
}
val dummyGroup = memoTable.getDummyGroup(cKey)
- val shapes = rules.map(_.shape())
- findPaths(GroupNode(ras, dummyGroup), shapes) {
- path => rules.foreach(rule => applyRule(panel, cKey, rule, path))
+ findPaths(GroupNode(ras, dummyGroup), ruleShapes, List(new
FromSingleNode[T](can))) {
+ path =>
+ val rootNode = path.node().self()
+ if (rootNode.isCanonical) {
+ assert(rootNode.asCanonical() eq can)
+ }
+ rules.foreach(rule => applyRule(panel, cKey, rule, path))
}
}
@@ -137,27 +147,34 @@ object DpPlanner {
cKey: RasClusterKey): Unit = {
val dummyGroup = memoTable.getDummyGroup(cKey)
cKey.propSets(memoTable).foreach {
- constraintSet =>
+ constraintSet: PropertySet[T] =>
val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
if (enforcerRules.nonEmpty) {
- val shapes = enforcerRules.map(_.shape())
- findPaths(GroupNode(ras, dummyGroup), shapes) {
+ val shapes = enforcerRuleSet.ruleShapesOf(constraintSet)
+ findPaths(GroupNode(ras, dummyGroup), shapes, List.empty) {
path => enforcerRules.foreach(rule => applyRule(panel, cKey,
rule, path))
}
}
}
}
- private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
+ private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]], filters:
Seq[FilterWizard[T]])(
onFound: RasPath[T] => Unit): Unit = {
- val finder = shapes
+ val finderBuilder = shapes
.foldLeft(
PathFinder
.builder(ras, memoTable)) {
case (builder, shape) =>
builder.output(shape.wizard())
}
+
+ val finder = filters
+ .foldLeft(finderBuilder) {
+ case (builder, filter) =>
+ builder.filter(filter)
+ }
.build()
+
finder.find(gn).foreach(path => onFound(path))
}
@@ -191,5 +208,22 @@ object DpPlanner {
}
}
- private object ExploreAdjustment {}
+ private object ExploreAdjustment {
+   /**
+    * Filter used to start path finding from one specific canonical node. Canonical nodes
+    * other than `from` (compared by reference) are omitted; group nodes are always kept.
+    */
+   private class FromSingleNode[T <: AnyRef](from: CanonicalNode[T]) extends FilterWizard[T] {
+     override def omit(can: CanonicalNode[T]): FilterWizard.FilterAction[T] =
+       if (can eq from) FilterWizard.FilterAction.Continue(this)
+       else FilterWizard.FilterAction.omit
+
+     override def omit(group: GroupNode[T]): FilterWizard.FilterAction[T] =
+       FilterWizard.FilterAction.Continue(this)
+
+     override def advance(offset: Int, count: Int): FilterWizard.FilterAdvanceAction[T] =
+       // Only nodes of the root group are filtered, so hand over to a no-op filter from here on.
+       FilterWizard.FilterAdvanceAction.Continue(FilterWizards.none())
+   }
+ }
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index 6406b8fb1..c67120357 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.ras.memo
import org.apache.gluten.ras._
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.vis.GraphvizVisualizer
@@ -236,11 +236,11 @@ object Memo {
private object MemoCacheKey {
def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = {
assert(ras.isCanonical(self))
- MemoCacheKey[T](ras.toUnsafeKey(self))
+ MemoCacheKey[T](ras.toHashKey(self))
}
}
- private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T])
+ private case class MemoCacheKey[T <: AnyRef] private (delegate:
UnsafeHashKey[T])
}
trait MemoStore[T <: AnyRef] {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
index 126ae7766..253e9ec84 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
@@ -54,6 +54,21 @@ object FilterWizards {
OmitCycles[T](CycleDetector[GroupNode[T]]((one, other) => one.groupId() ==
other.groupId()))
}
+ /** Returns a filter whose every decision is `Continue`, i.e. a no-op filter. */
+ def none[T <: AnyRef](): FilterWizard[T] = {
+   NoFilter[T]()
+ }
+
+ // Pass-through wizard backing `none()`. It is named `NoFilter` rather than `None` so that
+ // `scala.None` is not shadowed inside `FilterWizards`.
+ private class NoFilter[T <: AnyRef] private () extends FilterWizard[T] {
+   override def omit(can: CanonicalNode[T]): FilterAction[T] = FilterAction.Continue(this)
+   override def omit(group: GroupNode[T]): FilterAction[T] = FilterAction.Continue(this)
+   override def advance(offset: Int, count: Int): FilterAdvanceAction[T] =
+     FilterAdvanceAction.Continue(this)
+ }
+
+ private object NoFilter {
+   def apply[T <: AnyRef](): NoFilter[T] = new NoFilter[T]()
+ }
+
// Cycle detection starts from the first visited group in the input path.
private class OmitCycles[T <: AnyRef] private (detector:
CycleDetector[GroupNode[T]])
extends FilterWizard[T] {
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
index c18936973..439b88a2c 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/EnforcerRule.scala
@@ -54,6 +54,7 @@ object EnforcerRule {
trait EnforcerRuleSet[T <: AnyRef] {
def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]]
+ def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]]
}
object EnforcerRuleSet {
@@ -73,21 +74,31 @@ object EnforcerRuleSet {
mutable.Map[PropertyDef[T, _ <: Property[T]], EnforcerRuleFactory[T]]()
private val buffer = mutable.Map[Property[T], Seq[RuleApplier[T]]]()
+ private val rulesBuffer = mutable.Map[PropertySet[T],
Seq[RuleApplier[T]]]()
+ private val shapesBuffer = mutable.Map[PropertySet[T], Seq[Shape[T]]]()
+
override def rulesOf(constraintSet: PropertySet[T]): Seq[RuleApplier[T]] =
{
- constraintSet.getMap.flatMap {
- case (constraintDef, constraint) =>
- buffer.getOrElseUpdate(
- constraint, {
- val factory =
- factoryBuffer.getOrElseUpdate(
- constraintDef,
- newEnforcerRuleFactory(ras, constraintDef))
- RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +:
factory
- .newEnforcerRules(constraint)
- .map(rule => RuleApplier(ras, closure, EnforcerRule(rule,
constraint)))
- }
- )
- }.toSeq
+ rulesBuffer.getOrElseUpdate(
+ constraintSet,
+ constraintSet.getMap.flatMap {
+ case (constraintDef, constraint) =>
+ buffer.getOrElseUpdate(
+ constraint, {
+ val factory =
+ factoryBuffer.getOrElseUpdate(
+ constraintDef,
+ newEnforcerRuleFactory(ras, constraintDef))
+ RuleApplier(ras, closure, EnforcerRule.builtin(constraint)) +:
factory
+ .newEnforcerRules(constraint)
+ .map(rule => RuleApplier(ras, closure, EnforcerRule(rule,
constraint)))
+ }
+ )
+ }.toSeq
+ )
+ }
+
+ // Shapes of the enforcer rules for `constraintSet`, cached per property set. The rules are
+ // resolved through `rulesOf` (itself cached) instead of reading `rulesBuffer` directly, so
+ // this no longer throws NoSuchElementException when it is called before `rulesOf`.
+ override def ruleShapesOf(constraintSet: PropertySet[T]): Seq[Shape[T]] = {
+   shapesBuffer.getOrElseUpdate(constraintSet, rulesOf(constraintSet).map(_.shape()))
+ }
}
}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index 01e826f06..6b4082c7e 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.ras.rule
import org.apache.gluten.ras._
-import org.apache.gluten.ras.Ras.UnsafeKey
+import org.apache.gluten.ras.Ras.UnsafeHashKey
import org.apache.gluten.ras.memo.Closure
import org.apache.gluten.ras.path.InClusterPath
import org.apache.gluten.ras.property.PropertySet
@@ -43,14 +43,14 @@ object RuleApplier {
private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure:
Closure[T], rule: RasRule[T])
extends RuleApplier[T] {
- private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
+ private val deDup = mutable.Map[RasClusterKey,
mutable.Set[UnsafeHashKey[T]]]()
override def apply(icp: InClusterPath[T]): Unit = {
val cKey = icp.cluster()
val path = icp.path()
val plan = path.plan()
val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
- val pKey = ras.toUnsafeKey(plan)
+ val pKey = ras.toHashKey(plan)
if (appliedPlans.contains(pKey)) {
return
}
@@ -76,7 +76,7 @@ object RuleApplier {
closure: Closure[T],
rule: EnforcerRule[T])
extends RuleApplier[T] {
- private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]()
+ private val deDup = mutable.Map[RasClusterKey,
mutable.Set[UnsafeHashKey[T]]]()
private val constraint = rule.constraint()
private val constraintDef = constraint.definition()
@@ -88,7 +88,7 @@ object RuleApplier {
return
}
val plan = path.plan()
- val pKey = ras.toUnsafeKey(plan)
+ val pKey = ras.toHashKey(plan)
val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set())
if (appliedPlans.contains(pKey)) {
return
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 018c8087e..b420d8c29 100644
---
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -29,7 +29,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState:
MemoState[T], best
private val allGroups = memoState.allGroups()
private val allClusters = memoState.clusterLookup()
- private val nodeToId = mutable.Map[InGroupNode.HashKey, Int]()
+ private val nodeToId = mutable.Map[InGroupNode.UniqueKey, Int]()
def format(): String = {
val rootGroupId = best.rootGroupId()
@@ -156,7 +156,7 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T],
memoState: MemoState[T], best
group: RasGroup[T],
node: CanonicalNode[T]): String = {
val ign = InGroupNode(group.id(), node)
- val nodeId = nodeToId.getOrElseUpdate(ign.toHashKey, nodeToId.size)
+ val nodeId = nodeToId.getOrElseUpdate(ign.toUniqueKey, nodeToId.size)
s"[$nodeId][Cost ${costs(ign)
.map {
case c if ras.isInfCost(c) => "<INF>"
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index b29e0c267..2f3ef348c 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.ras.RasConfig.PlannerType
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.memo.Memo
import org.apache.gluten.ras.path.Pattern
+import org.apache.gluten.ras.path.Pattern.Matchers
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.scalatest.funsuite.AnyFunSuite
@@ -265,6 +266,58 @@ abstract class RasSuite extends AnyFunSuite {
assert(allPaths.size == 15)
}
+ test(s"Rule dependency") {
+   // Chained rewrites: Op1 turns Leaf(70) into Leaf(69), Op2 turns Leaf(69) into Leaf(68),
+   // and Op3 turns Leaf(68) into Leaf(67). Each rule only matches the previous rule's output.
+
+   object Op1 extends RasRule[TestNode] {
+     override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
+
+     override def shift(node: TestNode): Iterable[TestNode] = node match {
+       case Leaf(70) => List(Leaf(69))
+       case _ => List.empty
+     }
+   }
+
+   object Op2 extends RasRule[TestNode] {
+     override def shape(): Shape[TestNode] =
+       Shapes.pattern(Pattern.leaf[TestNode](Matchers.clazz(classOf[Leaf])).build())
+
+     override def shift(node: TestNode): Iterable[TestNode] = node match {
+       case Leaf(69) => List(Leaf(68))
+       case _ => List.empty
+     }
+   }
+
+   object Op3 extends RasRule[TestNode] {
+     override def shape(): Shape[TestNode] =
+       Shapes.pattern(Pattern.any[TestNode].build())
+
+     override def shift(node: TestNode): Iterable[TestNode] = node match {
+       case Leaf(68) => List(Leaf(67))
+       case _ => List.empty
+     }
+   }
+
+   // The registration order (Op3, Op1, Op2) differs from the dependency order.
+   val ras =
+     Ras[TestNode](
+       PlanModelImpl,
+       CostModelImpl,
+       MetadataModelImpl,
+       PropertyModelImpl,
+       ExplainImpl,
+       RasRule.Factory.reuse(List(Op3, Op1, Op2)))
+       .withNewConfig(_ => conf)
+
+   val input = Unary(90, Unary(90, Leaf(70)))
+   val result = ras.newPlanner(input).plan()
+
+   // The leaf must have gone through all three rewrites.
+   assert(result == Unary(90, Unary(90, Leaf(67))))
+ }
+
test(s"Unary node insertion") {
object InsertUnary2 extends RasRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = node match {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]