This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 19da3bb1c [CORE][VL] RAS: Group expansion support (#5323)
19da3bb1c is described below
commit 19da3bb1cc8677913c097d7c4ccd1fc17189c438
Author: Hongze Zhang <[email protected]>
AuthorDate: Mon Apr 8 18:16:45 2024 +0800
[CORE][VL] RAS: Group expansion support (#5323)
---
.../src/main/scala/org/apache/gluten/ras/Ras.scala | 6 +-
.../main/scala/org/apache/gluten/ras/RasNode.scala | 2 +-
.../org/apache/gluten/ras/best/BestFinder.scala | 2 -
.../gluten/ras/best/GroupBasedBestFinder.scala | 2 +-
.../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 23 ++--
.../gluten/ras/exaustive/ExhaustivePlanner.scala | 21 ++--
.../apache/gluten/ras/memo/ForwardMemoTable.scala | 25 +++--
.../scala/org/apache/gluten/ras/memo/Memo.scala | 28 +++--
.../org/apache/gluten/ras/memo/MemoTable.scala | 17 ++-
.../org/apache/gluten/ras/path/OutputFilter.scala | 33 ++++--
.../org/apache/gluten/ras/path/OutputWizard.scala | 100 +++++++++++-------
.../org/apache/gluten/ras/path/PathFinder.scala | 35 +++++--
.../scala/org/apache/gluten/ras/path/Pattern.scala | 9 +-
.../org/apache/gluten/ras/rule/RuleApplier.scala | 12 +--
.../scala/org/apache/gluten/ras/rule/Shape.scala | 10 ++
.../org/apache/gluten/ras/PropertySuite.scala | 116 ++++++++++++++++-----
.../scala/org/apache/gluten/ras/RasSuite.scala | 67 ++++++++++++
.../scala/org/apache/gluten/ras/RasSuiteBase.scala | 2 +-
.../org/apache/gluten/ras/mock/MockMemoState.scala | 5 +
.../org/apache/gluten/ras/mock/MockRasPath.scala | 2 +-
.../apache/gluten/ras/path/PathFinderSuite.scala | 60 ++++++++++-
.../org/apache/gluten/ras/path/WizardSuite.scala | 14 +++
.../gluten/ras/specific/DistributedSuite.scala | 2 +
23 files changed, 449 insertions(+), 144 deletions(-)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
index 9910fab6f..f3d46847e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
@@ -31,7 +31,7 @@ trait Optimization[T <: AnyRef] {
constraintSet: PropertySet[T],
altConstraintSets: Seq[PropertySet[T]]): RasPlanner[T]
- def propSetsOf(plan: T): PropertySet[T]
+ def propSetOf(plan: T): PropertySet[T]
def withNewConfig(confFunc: RasConfig => RasConfig): Optimization[T]
}
@@ -49,7 +49,7 @@ object Optimization {
implicit class OptimizationImplicits[T <: AnyRef](opt: Optimization[T]) {
def newPlanner(plan: T): RasPlanner[T] = {
- opt.newPlanner(plan, opt.propSetsOf(plan), List.empty)
+ opt.newPlanner(plan, opt.propSetOf(plan), List.empty)
}
def newPlanner(plan: T, constraintSet: PropertySet[T]): RasPlanner[T] = {
opt.newPlanner(plan, constraintSet, List.empty)
@@ -131,7 +131,7 @@ class Ras[T <: AnyRef] private (
RasPlanner(this, altConstraintSets, constraintSet, plan)
}
- override def propSetsOf(plan: T): PropertySet[T] =
propertySetFactory().get(plan)
+ override def propSetOf(plan: T): PropertySet[T] =
propertySetFactory().get(plan)
private[ras] def withNewChildren(node: T, newChildren: Seq[T]): T = {
val oldChildren = planModel.childrenOf(node)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
index 878020391..65ff8b735 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala
@@ -54,7 +54,7 @@ trait CanonicalNode[T <: AnyRef] extends RasNode[T] {
object CanonicalNode {
def apply[T <: AnyRef](ras: Ras[T], canonical: T): CanonicalNode[T] = {
assert(ras.isCanonical(canonical))
- val propSet = ras.propSetsOf(canonical)
+ val propSet = ras.propSetOf(canonical)
val children = ras.planModel.childrenOf(canonical)
new CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
index 0912ab536..90a0adfb2 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala
@@ -52,12 +52,10 @@ object BestFinder {
private[best] def newBest[T <: AnyRef](
ras: Ras[T],
- allGroups: Seq[RasGroup[T]],
group: RasGroup[T],
groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = {
val bestPath = groupToCosts(group.id()).best()
- val bestRoot = bestPath.rasPath.node()
val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id,
g.bestNode) }.toSeq
val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
groupToCosts.foreach {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
index 6db3600de..effebd41b 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala
@@ -43,7 +43,7 @@ private class GroupBasedBestFinder[T <: AnyRef](
s"Best path not found. Memo state (Graphviz): \n" +
s"${memoState.formatGraphvizWithoutBest(groupId)}")
}
- BestFinder.newBest(ras, allGroups, group, groupToCosts)
+ BestFinder.newBest(ras, group, groupToCosts)
}
private def fillBests(group: RasGroup[T]): Map[Int, KnownCostGroup[T]] = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 1be728ae6..8acf66c59 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -99,8 +99,6 @@ object DpPlanner {
rules: Seq[RuleApplier[T]],
enforcerRuleSet: EnforcerRuleSet[T])
extends DpClusterAlgo.Adjustment[T] {
- private val allGroups = memoTable.allGroups()
- private val clusterLookup = cKey => memoTable.getCluster(cKey)
override def exploreChildX(
panel: Panel[InClusterNode[T], RasClusterKey],
@@ -127,33 +125,30 @@ object DpPlanner {
if (rules.isEmpty) {
return
}
- val cluster = clusterLookup(cKey)
- cluster.nodes().foreach {
- node =>
- val shapes = rules.map(_.shape())
- findPaths(node, shapes)(path => rules.foreach(rule =>
applyRule(panel, cKey, rule, path)))
+ val dummyGroup = memoTable.getDummyGroup(cKey)
+ val shapes = rules.map(_.shape())
+ findPaths(GroupNode(ras, dummyGroup), shapes) {
+ path => rules.foreach(rule => applyRule(panel, cKey, rule, path))
}
}
private def applyEnforcerRules(
panel: Panel[InClusterNode[T], RasClusterKey],
cKey: RasClusterKey): Unit = {
- val cluster = clusterLookup(cKey)
+ val dummyGroup = memoTable.getDummyGroup(cKey)
cKey.propSets(memoTable).foreach {
constraintSet =>
val enforcerRules = enforcerRuleSet.rulesOf(constraintSet)
if (enforcerRules.nonEmpty) {
val shapes = enforcerRules.map(_.shape())
- cluster.nodes().foreach {
- node =>
- findPaths(node, shapes)(
- path => enforcerRules.foreach(rule => applyRule(panel, cKey,
rule, path)))
+ findPaths(GroupNode(ras, dummyGroup), shapes) {
+ path => enforcerRules.foreach(rule => applyRule(panel, cKey,
rule, path))
}
}
}
}
- private def findPaths(canonical: CanonicalNode[T], shapes: Seq[Shape[T]])(
+ private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
onFound: RasPath[T] => Unit): Unit = {
val finder = shapes
.foldLeft(
@@ -163,7 +158,7 @@ object DpPlanner {
builder.output(shape.wizard())
}
.build()
- finder.find(canonical).foreach(path => onFound(path))
+ finder.find(gn).foreach(path => onFound(path))
}
private def applyRule(
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
index 47945fc14..3db649b64 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala
@@ -94,7 +94,7 @@ object ExhaustivePlanner {
applyRules()
}
- private def findPaths(canonical: CanonicalNode[T], shapes: Seq[Shape[T]])(
+ private def findPaths(gn: GroupNode[T], shapes: Seq[Shape[T]])(
onFound: RasPath[T] => Unit): Unit = {
val finder = shapes
.foldLeft(
@@ -104,7 +104,7 @@ object ExhaustivePlanner {
builder.output(shape.wizard())
}
.build()
- finder.find(canonical).foreach(path => onFound(path))
+ finder.find(gn).foreach(path => onFound(path))
}
private def applyRule(rule: RuleApplier[T], icp: InClusterPath[T]): Unit =
{
@@ -120,12 +120,10 @@ object ExhaustivePlanner {
.clusterLookup()
.foreach {
case (cKey, cluster) =>
- cluster
- .nodes()
- .foreach(
- node =>
- findPaths(node, shapes)(
- path => rules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))))
+ val dummyGroup = memoState.getDummyGroup(cKey)
+ findPaths(GroupNode(ras, dummyGroup), shapes) {
+ path => rules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))
+ }
}
}
@@ -137,10 +135,9 @@ object ExhaustivePlanner {
if (enforcerRules.nonEmpty) {
val shapes = enforcerRules.map(_.shape())
val cKey = group.clusterKey()
- memoState.clusterLookup()(cKey).nodes().foreach {
- node =>
- findPaths(node, shapes)(
- path => enforcerRules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path))))
+ val dummyGroup = memoState.getDummyGroup(cKey)
+ findPaths(GroupNode(ras, dummyGroup), shapes) {
+ path => enforcerRules.foreach(rule => applyRule(rule,
InClusterPath(cKey, path)))
}
}
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index e3ae03ebf..0895544d5 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -29,11 +29,12 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])
import ForwardMemoTable._
private val groupBuffer: mutable.ArrayBuffer[RasGroup[T]] =
mutable.ArrayBuffer()
+ private val dummyGroupBuffer: mutable.ArrayBuffer[RasGroup[T]] =
+ mutable.ArrayBuffer[RasGroup[T]]()
private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] =
mutable.ArrayBuffer()
private val clusterBuffer: mutable.ArrayBuffer[MutableRasCluster[T]] =
mutable.ArrayBuffer()
private val clusterDisjointSet: IndexDisjointSet = IndexDisjointSet()
- private val clusterDummyGroupBuffer = mutable.ArrayBuffer[RasGroup[T]]()
private val groupLookup: mutable.ArrayBuffer[mutable.Map[PropertySet[T],
RasGroup[T]]] =
mutable.ArrayBuffer()
@@ -55,13 +56,16 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])
clusterDisjointSet.grow()
groupLookup += mutable.Map()
// Normal groups start with ID 0, so it's safe to use negative IDs for
dummy groups.
- clusterDummyGroupBuffer += RasGroup(ras, key, -clusterId,
ras.propertySetFactory().any())
+ // Dummy group ID starts from -1.
+ dummyGroupBuffer += RasGroup(ras, key, -(clusterId + 1),
ras.propertySetFactory().any())
key
}
- override def dummyGroupOf(key: RasClusterKey): RasGroup[T] = {
+ override def getDummyGroup(key: RasClusterKey): RasGroup[T] = {
val ancestor = ancestorClusterIdOf(key)
- clusterDummyGroupBuffer(ancestor)
+ val out = dummyGroupBuffer(ancestor)
+ assert(out.id() == -(ancestor + 1))
+ out
}
override def groupOf(key: RasClusterKey, propSet: PropertySet[T]):
RasGroup[T] = {
@@ -142,12 +146,21 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])
memoWriteCount += 1
}
- override def getGroup(id: Int): RasGroup[T] = groupBuffer(id)
+ override def getGroup(id: Int): RasGroup[T] = {
+ if (id < 0) {
+ val out = dummyGroupBuffer((-id - 1))
+ assert(out.id() == id)
+ return out
+ }
+ groupBuffer(id)
+ }
override def allClusters(): Seq[RasClusterKey] = clusterKeyBuffer
override def allGroups(): Seq[RasGroup[T]] = groupBuffer
+ override def allDummyGroups(): Seq[RasGroup[T]] = dummyGroupBuffer
+
private def ancestorClusterIdOf(key: RasClusterKey): Int = {
clusterDisjointSet.find(key.id())
}
@@ -156,7 +169,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])
assert(clusterKeyBuffer.size == clusterBuffer.size)
assert(clusterKeyBuffer.size == clusterDisjointSet.size)
assert(clusterKeyBuffer.size == groupLookup.size)
- assert(clusterKeyBuffer.size == clusterDummyGroupBuffer.size)
+ assert(clusterKeyBuffer.size == dummyGroupBuffer.size)
}
override def probe(): MemoTable.Probe[T] = new
ForwardMemoTable.Probe[T](this)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index 49281a82d..c1bb0a6bf 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -75,7 +75,7 @@ object Memo {
}
private def dummyGroupOf(clusterKey: RasClusterKey): RasGroup[T] = {
- memoTable.dummyGroupOf(clusterKey)
+ memoTable.getDummyGroup(clusterKey)
}
private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = {
@@ -84,7 +84,7 @@ object Memo {
private def prepareInsert(n: T): Prepare[T] = {
if (ras.isGroupLeaf(n)) {
- val group = memoTable.allGroups()(ras.planModel.getGroupId(n))
+ val group = memoTable.getGroup(ras.planModel.getGroupId(n))
return Prepare.cluster(this, group.clusterKey())
}
@@ -129,7 +129,7 @@ object Memo {
// TODO: Traverse up the tree to do more merges.
private def prepareInsert(node: T): Prepare[T] = {
if (ras.isGroupLeaf(node)) {
- val group =
parent.memoTable.allGroups()(ras.planModel.getGroupId(node))
+ val group = parent.memoTable.getGroup(ras.planModel.getGroupId(node))
val residentCluster = group.clusterKey()
if (residentCluster == targetCluster) {
@@ -246,6 +246,7 @@ object Memo {
trait MemoStore[T <: AnyRef] {
def getCluster(key: RasClusterKey): RasCluster[T]
+ def getDummyGroup(key: RasClusterKey): RasGroup[T]
def getGroup(id: Int): RasGroup[T]
}
@@ -260,6 +261,7 @@ object MemoStore {
trait MemoState[T <: AnyRef] extends MemoStore[T] {
def ras(): Ras[T]
def clusterLookup(): Map[RasClusterKey, RasCluster[T]]
+ def clusterDummyGroupLookup(): Map[RasClusterKey, RasGroup[T]]
def allClusters(): Iterable[RasCluster[T]]
def allGroups(): Seq[RasGroup[T]]
}
@@ -268,19 +270,31 @@ object MemoState {
def apply[T <: AnyRef](
ras: Ras[T],
clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
- allGroups: Seq[RasGroup[T]]): MemoState[T] = {
- MemoStateImpl(ras, clusterLookup, allGroups)
+ clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+ allGroups: Seq[RasGroup[T]],
+ allDummyGroups: Seq[RasGroup[T]]): MemoState[T] = {
+ MemoStateImpl(ras, clusterLookup, clusterDummyGroupLookup, allGroups,
allDummyGroups)
}
private case class MemoStateImpl[T <: AnyRef](
override val ras: Ras[T],
override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
- override val allGroups: Seq[RasGroup[T]])
+ override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+ override val allGroups: Seq[RasGroup[T]],
+ allDummyGroups: Seq[RasGroup[T]])
extends MemoState[T] {
private val allClustersCopy = clusterLookup.values
override def getCluster(key: RasClusterKey): RasCluster[T] =
clusterLookup(key)
- override def getGroup(id: Int): RasGroup[T] = allGroups(id)
+ override def getDummyGroup(key: RasClusterKey): RasGroup[T] =
clusterDummyGroupLookup(key)
+ override def getGroup(id: Int): RasGroup[T] = {
+ if (id < 0) {
+ val out = allDummyGroups((-id - 1))
+ assert(out.id() == id)
+ return out
+ }
+ allGroups(id)
+ }
override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 3baba8eae..8e5dac02e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -30,6 +30,7 @@ sealed trait MemoTable[T <: AnyRef] extends MemoStore[T] {
def allClusters(): Seq[RasClusterKey]
def allGroups(): Seq[RasGroup[T]]
+ def allDummyGroups(): Seq[RasGroup[T]]
def getClusterPropSets(key: RasClusterKey): Set[PropertySet[T]]
@@ -44,7 +45,6 @@ object MemoTable {
trait Writable[T <: AnyRef] extends MemoTable[T] {
def newCluster(metadata: Metadata): RasClusterKey
def groupOf(key: RasClusterKey, propertySet: PropertySet[T]): RasGroup[T]
- def dummyGroupOf(key: RasClusterKey): RasGroup[T]
def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit
def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit
@@ -74,7 +74,20 @@ object MemoTable {
.allClusters()
.map(key => key -> ImmutableRasCluster(table.ras,
table.getCluster(key)))
.toMap
- MemoState(table.ras, immutableClusters, table.allGroups())
+ val immutableDummyGroups = table
+ .allClusters()
+ .map(key => key -> table.getDummyGroup(key))
+ .toMap
+ table.allDummyGroups().zipWithIndex.foreach {
+ case (group, idx) =>
+ assert(group.id() == -(idx + 1))
+ }
+ MemoState(
+ table.ras,
+ immutableClusters,
+ immutableDummyGroups,
+ table.allGroups(),
+ table.allDummyGroups())
}
def doExhaustively(func: => Unit): Unit = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
index ed1c326f9..126ae7766 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputFilter.scala
@@ -17,14 +17,15 @@
package org.apache.gluten.ras.path
import org.apache.gluten.ras.{CanonicalNode, GroupNode}
-import org.apache.gluten.ras.path.FilterWizard.FilterAction
-import org.apache.gluten.ras.path.OutputWizard.OutputAction
+import org.apache.gluten.ras.path.FilterWizard.{FilterAction,
FilterAdvanceAction}
+import org.apache.gluten.ras.path.OutputWizard.{AdvanceAction, OutputAction}
import org.apache.gluten.ras.util.CycleDetector
trait FilterWizard[T <: AnyRef] {
import FilterWizard._
def omit(can: CanonicalNode[T]): FilterAction[T]
- def omit(group: GroupNode[T], offset: Int, count: Int): FilterAction[T]
+ def omit(group: GroupNode[T]): FilterAction[T]
+ def advance(offset: Int, count: Int): FilterAdvanceAction[T]
}
object FilterWizard {
@@ -40,6 +41,11 @@ object FilterWizard {
case class Continue[T <: AnyRef](newWizard: FilterWizard[T]) extends
FilterAction[T]
}
+
+ sealed trait FilterAdvanceAction[T <: AnyRef]
+ object FilterAdvanceAction {
+ case class Continue[T <: AnyRef](newWizard: FilterWizard[T]) extends
FilterAdvanceAction[T]
+ }
}
object FilterWizards {
@@ -55,12 +61,15 @@ object FilterWizards {
FilterAction.Continue(this)
}
- override def omit(group: GroupNode[T], offset: Int, count: Int):
FilterAction[T] = {
+ override def omit(group: GroupNode[T]): FilterAction[T] = {
if (detector.contains(group)) {
return FilterAction.omit
}
FilterAction.Continue(new OmitCycles(detector.append(group)))
}
+
+ override def advance(offset: Int, count: Int): FilterAdvanceAction[T] =
+ FilterAdvanceAction.Continue(this)
}
private object OmitCycles {
@@ -98,11 +107,11 @@ object OutputFilter {
}
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] = {
- filterWizard.omit(group: GroupNode[T], offset: Int, count: Int) match {
+ override def visit(group: GroupNode[T]): OutputAction[T] = {
+ filterWizard.omit(group: GroupNode[T]) match {
case FilterAction.Omit() => OutputAction.stop
case FilterAction.Continue(newFilterWizard) =>
- outputWizard.advance(group, offset, count) match {
+ outputWizard.visit(group) match {
case stop @ OutputAction.Stop(_) => stop
case OutputAction.Continue(drain, newOutputWizard) =>
OutputAction.Continue(drain, new
OutputFilterImpl(newOutputWizard, newFilterWizard))
@@ -110,6 +119,16 @@ object OutputFilter {
}
}
+ override def advance(offset: Int, count: Int):
OutputWizard.AdvanceAction[T] = {
+ val newOutputWizard = outputWizard.advance(offset, count) match {
+ case AdvanceAction.Continue(newWizard) => newWizard
+ }
+ val newFilterWizard = filterWizard.advance(offset, count) match {
+ case FilterAdvanceAction.Continue(newWizard) => newWizard
+ }
+ AdvanceAction.Continue(new OutputFilterImpl(newOutputWizard,
newFilterWizard))
+ }
+
override def withPathKey(newKey: PathKey): OutputWizard[T] =
new OutputFilterImpl[T](outputWizard.withPathKey(newKey), filterWizard)
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
index 0e40c36d8..73e4a5c19 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/OutputWizard.scala
@@ -17,7 +17,7 @@
package org.apache.gluten.ras.path
import org.apache.gluten.ras.{CanonicalNode, GroupNode, Ras, RasGroup}
-import org.apache.gluten.ras.path.OutputWizard.{OutputAction, PathDrain}
+import org.apache.gluten.ras.path.OutputWizard.{AdvanceAction, OutputAction,
PathDrain}
import scala.collection.{mutable, Seq}
@@ -25,11 +25,11 @@ trait OutputWizard[T <: AnyRef] {
import OutputWizard._
// Visit a new node.
def visit(can: CanonicalNode[T]): OutputAction[T]
- // The returned object is a wizard for one of the node's children at the
+ // Visit a new group.
+ def visit(group: GroupNode[T]): OutputAction[T]
+ // The returned action is typically a wizard for one of the node's children
at the
// known offset among all children.
- def advance(group: GroupNode[T], offset: Int, count: Int): OutputAction[T]
- // The returned wizard would be same with this wizard
- // except it drains paths with the input path key.
+ def advance(offset: Int, count: Int): AdvanceAction[T]
def withPathKey(newKey: PathKey): OutputWizard[T]
}
@@ -50,6 +50,11 @@ object OutputWizard {
extends OutputAction[T]
}
+ sealed trait AdvanceAction[T <: AnyRef]
+ object AdvanceAction {
+ case class Continue[T <: AnyRef](newWizard: OutputWizard[T]) extends
AdvanceAction[T]
+ }
+
// Path drain provides possibility to lazily materialize the yielded paths
using path key.
// Otherwise if each wizard emits its own paths during visiting, the de-dup
operation
// will be required and could cause serious performance issues.
@@ -97,12 +102,8 @@ object OutputWizard {
new NodePrepareImpl[T](ras, wizard, allGroups, can)
}
- def prepareForGroup(
- ras: Ras[T],
- group: GroupNode[T],
- offset: Int,
- count: Int): GroupPrepare[T] = {
- new GroupPrepareImpl[T](ras, wizard, group, offset, count)
+ def prepareForGroup(ras: Ras[T], group: GroupNode[T]): GroupPrepare[T] = {
+ new GroupPrepareImpl[T](ras, wizard, group)
}
}
@@ -112,7 +113,7 @@ object OutputWizard {
}
sealed trait GroupPrepare[T <: AnyRef] {
- def advance(): Terminate[T]
+ def visit(): Terminate[T]
}
sealed trait Terminate[T <: AnyRef] {
@@ -154,12 +155,10 @@ object OutputWizard {
private class GroupPrepareImpl[T <: AnyRef](
ras: Ras[T],
wizard: OutputWizard[T],
- group: GroupNode[T],
- offset: Int,
- count: Int)
+ group: GroupNode[T])
extends GroupPrepare[T] {
- override def advance(): Terminate[T] = {
- val action = wizard.advance(group, offset, count)
+ override def visit(): Terminate[T] = {
+ val action = wizard.visit(group)
val drained = if (action.drain().isEmpty()) {
List.empty
} else {
@@ -201,8 +200,12 @@ object OutputWizards {
OutputAction.Stop(PathDrain.none)
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] =
+ override def visit(group: GroupNode[T]): OutputAction[T] = {
OutputAction.Stop(PathDrain.none)
+ }
+
+ override def advance(offset: Int, count: Int):
OutputWizard.AdvanceAction[T] =
+ AdvanceAction.Continue(this)
override def withPathKey(newKey: PathKey): OutputWizard[T] = this
}
@@ -219,9 +222,11 @@ object OutputWizards {
OutputAction.Continue(PathDrain.none, this)
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] =
+ override def visit(group: GroupNode[T]): OutputAction[T] =
OutputAction.Continue(PathDrain.none, this)
+ override def advance(offset: Int, count: Int): AdvanceAction[T] =
AdvanceAction.Continue(this)
+
override def withPathKey(newKey: PathKey): OutputWizard[T] =
new Emit[T](PathDrain.Specific(List(newKey)))
}
@@ -267,12 +272,20 @@ object OutputWizards {
act(actions)
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] = {
+ override def visit(group: GroupNode[T]): OutputAction[T] = {
val actions = wizards
- .map(_.advance(group, offset, count))
+ .map(_.visit(group))
act(actions)
}
+ override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+ val newWizards = wizards.map(_.advance(offset, count)).map {
+ case AdvanceAction.Continue(newWizard) =>
+ newWizard
+ }
+ AdvanceAction.Continue(new Union(newWizards))
+ }
+
override def withPathKey(newKey: PathKey): OutputWizard[T] =
new Union[T](wizards.map(w => w.withPathKey(newKey)))
}
@@ -331,13 +344,17 @@ object OutputWizards {
OutputAction.Continue(PathDrain.none, this)
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] = {
- var skipCursor = ele + 1
- (0 until offset).foreach(_ => skipCursor = mask.skip(skipCursor))
- if (mask.isAny(skipCursor)) {
+ override def visit(group: GroupNode[T]): OutputAction[T] = {
+ if (mask.isAny(ele)) {
return OutputAction.Stop(drain)
}
- OutputAction.Continue(PathDrain.none, new WithMask[T](drain, mask,
skipCursor))
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+ var skipCursor = ele + 1
+ (0 until offset).foreach(_ => skipCursor = mask.skip(skipCursor))
+ AdvanceAction.Continue(new WithMask[T](drain, mask, skipCursor))
}
override def withPathKey(newKey: PathKey): OutputWizard[T] =
@@ -372,13 +389,16 @@ object OutputWizards {
OutputAction.Continue(PathDrain.none, this)
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] = {
- // Omit should be done in #advance.
- val child = pNode.children(count)(offset)
- if (child.skip()) {
+ override def visit(group: GroupNode[T]): OutputAction[T] = {
+ if (pNode.skip()) {
return OutputAction.Stop(drain)
}
- OutputAction.Continue(PathDrain.none, new WithPattern(drain, pattern,
child))
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+ val child = pNode.children(count)(offset)
+ AdvanceAction.Continue(new WithPattern(drain, pattern, child))
}
override def withPathKey(newKey: PathKey): OutputWizard[T] =
@@ -398,8 +418,8 @@ object OutputWizards {
override def visit(can: CanonicalNode[T]): OutputAction[T] = {
assert(
- currentDepth <= depth,
- "Current depth already larger than the maximum depth to prune. " +
+ currentDepth < depth,
+ "Current depth already larger than (or equals) the maximum depth to
prune. " +
"It probably because a zero depth was specified for path finding."
)
if (can.isLeaf()) {
@@ -408,13 +428,15 @@ object OutputWizards {
OutputAction.Continue(PathDrain.none, this)
}
- override def advance(group: GroupNode[T], offset: Int, count: Int):
OutputAction[T] = {
- assert(currentDepth <= depth)
- val nextDepth = currentDepth + 1
- if (nextDepth > depth) {
+ override def visit(group: GroupNode[T]): OutputAction[T] = {
+ if (currentDepth >= depth) {
return OutputAction.Stop(drain)
}
- OutputAction.Continue(PathDrain.none, new WithMaxDepth(drain, depth,
nextDepth))
+ OutputAction.Continue(PathDrain.none, this)
+ }
+
+ override def advance(offset: Int, count: Int): AdvanceAction[T] = {
+ AdvanceAction.Continue(new WithMaxDepth(drain, depth, currentDepth + 1))
}
override def withPathKey(newKey: PathKey): OutputWizard[T] =
@@ -423,7 +445,7 @@ object OutputWizards {
private object WithMaxDepth {
def apply[T <: AnyRef](depth: Int): WithMaxDepth[T] = {
- new WithMaxDepth(PathDrain.Specific(List(PathKey.random())), depth, 1)
+ new WithMaxDepth(PathDrain.Specific(List(PathKey.random())), depth, 0)
}
}
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
index c5949a9d7..78aed142a 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/PathFinder.scala
@@ -18,11 +18,13 @@ package org.apache.gluten.ras.path
import org.apache.gluten.ras.{CanonicalNode, GroupNode, Ras}
import org.apache.gluten.ras.memo.MemoStore
+import org.apache.gluten.ras.path.OutputWizard.AdvanceAction
import scala.collection.mutable
trait PathFinder[T <: AnyRef] {
def find(base: CanonicalNode[T]): Iterable[RasPath[T]]
+ def find(base: GroupNode[T]): Iterable[RasPath[T]]
def find(base: RasPath[T]): Iterable[RasPath[T]]
}
@@ -88,6 +90,14 @@ object PathFinder {
all
}
+ override def find(base: GroupNode[T]): Iterable[RasPath[T]] = {
+ val all =
+ wizard.prepareForGroup(ras, base).visit().onContinue {
+ newWizard => enumerateFromGroup(base, newWizard)
+ }
+ all
+ }
+
override def find(base: RasPath[T]): Iterable[RasPath[T]] = {
val can = base.node().self().asCanonical()
val all = wizard.prepareForNode(ras, memoStore.asGroupSupplier(),
can).visit().onContinue {
@@ -123,9 +133,13 @@ object PathFinder {
childrenGroups.zipWithIndex.map {
case (childGroup, index) =>
wizard
- .prepareForGroup(ras, childGroup, index, childrenGroups.size)
- .advance()
- .onContinue(newWizard => enumerateFromGroup(childGroup,
newWizard))
+ .advance(index, childrenGroups.size) match {
+ case AdvanceAction.Continue(newWizard) =>
+ newWizard
+ .prepareForGroup(ras, childGroup)
+ .visit()
+ .onContinue(newWizard => enumerateFromGroup(childGroup,
newWizard))
+ }
}
RasPath.cartesianProduct(ras, canonical, expandedChildren)
}
@@ -168,11 +182,16 @@ object PathFinder {
children.zip(childrenGroups).zipWithIndex.map {
case ((child, childGroup), index) =>
wizard
- .prepareForGroup(ras, childGroup, index, childrenGroups.size)
- .advance()
- .onContinue {
- newWizard => diveFromGroup(depth - 1,
GroupedPathNode(childGroup, child), newWizard)
- }
+ .advance(index, childrenGroups.size) match {
+ case AdvanceAction.Continue(newWizard) =>
+ newWizard
+ .prepareForGroup(ras, childGroup)
+ .visit()
+ .onContinue {
+ newWizard =>
+ diveFromGroup(depth - 1, GroupedPathNode(childGroup,
child), newWizard)
+ }
+ }
}
)
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index b694279ec..0471fb42a 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -90,21 +90,20 @@ object Pattern {
private case class PatternImpl[T <: AnyRef](root: Node[T]) extends
Pattern[T] {
override def matches(path: RasPath[T], depth: Int): Boolean = {
- assert(depth >= 1)
+ assert(depth >= 0)
assert(depth <= path.height())
def dfs(remainingDepth: Int, patternN: Node[T], n: PathNode[T]): Boolean
= {
assert(remainingDepth >= 0)
- assert(n.self().isCanonical)
if (remainingDepth == 0) {
return true
}
+ if (patternN.skip()) {
+ return true
+ }
val can = n.self().asCanonical()
if (patternN.abort(can)) {
return false
}
- if (patternN.skip()) {
- return true
- }
if (!patternN.matches(can)) {
return false
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
index 0a7bf0c76..01e826f06 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala
@@ -20,6 +20,7 @@ import org.apache.gluten.ras._
import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.memo.Closure
import org.apache.gluten.ras.path.InClusterPath
+import org.apache.gluten.ras.property.PropertySet
import scala.collection.mutable
@@ -82,8 +83,8 @@ object RuleApplier {
override def apply(icp: InClusterPath[T]): Unit = {
val cKey = icp.cluster()
val path = icp.path()
- val can = path.node().self().asCanonical()
- if (can.propSet().get(constraintDef).satisfies(constraint)) {
+ val propSet = path.node().self().propSet()
+ if (propSet.get(constraintDef).satisfies(constraint)) {
return
}
val plan = path.plan()
@@ -92,13 +93,12 @@ object RuleApplier {
if (appliedPlans.contains(pKey)) {
return
}
- apply0(cKey, plan)
+ val constraintSet = propSet.withProp(constraint)
+ apply0(cKey, constraintSet, plan)
appliedPlans += pKey
}
- private def apply0(cKey: RasClusterKey, plan: T): Unit = {
- val propSet = ras.propertySetFactory().get(plan)
- val constraintSet = propSet.withProp(constraint)
+ private def apply0(cKey: RasClusterKey, constraintSet: PropertySet[T],
plan: T): Unit = {
val equivalents = rule.shift(plan)
equivalents.foreach {
equiv =>
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
index 400f0eeba..2a8255861 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
@@ -33,10 +33,20 @@ object Shapes {
new FixedHeight[T](height)
}
+ def pattern[T <: AnyRef](pattern: org.apache.gluten.ras.path.Pattern[T]):
Shape[T] = {
+ new Pattern(pattern)
+ }
+
def none[T <: AnyRef](): Shape[T] = {
new None()
}
+ private class Pattern[T <: AnyRef](pattern:
org.apache.gluten.ras.path.Pattern[T])
+ extends Shape[T] {
+ override def wizard(): OutputWizard[T] = OutputWizards.withPattern(pattern)
+ override def identify(path: RasPath[T]): Boolean = pattern.matches(path,
path.height())
+ }
+
private class FixedHeight[T <: AnyRef](height: Int) extends Shape[T] {
override def wizard(): OutputWizard[T] = OutputWizards.withMaxDepth(height)
override def identify(path: RasPath[T]): Boolean = path.height() == height
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
index e48604116..aed032226 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala
@@ -27,16 +27,29 @@ import org.scalatest.funsuite.AnyFunSuite
class ExhaustivePlannerPropertySuite extends PropertySuite {
override protected def conf: RasConfig = RasConfig(plannerType =
PlannerType.Exhaustive)
+ override protected def zeroDepth: Boolean = false
}
class DpPlannerPropertySuite extends PropertySuite {
override protected def conf: RasConfig = RasConfig(plannerType =
PlannerType.Dp)
+ override protected def zeroDepth: Boolean = false
+}
+
+class ExhaustivePlannerPropertyZeroDepthSuite extends PropertySuite {
+ override protected def conf: RasConfig = RasConfig(plannerType =
PlannerType.Exhaustive)
+ override protected def zeroDepth: Boolean = true
+}
+
+class DpPlannerPropertyZeroDepthSuite extends PropertySuite {
+ override protected def conf: RasConfig = RasConfig(plannerType =
PlannerType.Dp)
+ override protected def zeroDepth: Boolean = true
}
abstract class PropertySuite extends AnyFunSuite {
import PropertySuite._
protected def conf: RasConfig
+ protected def zeroDepth: Boolean
test("Group memo - cache") {
val ras =
@@ -44,7 +57,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModelWithOutEnforcerRules,
+ propertyModelWithoutEnforcerRules(),
ExplainImpl,
RasRule.Factory.none())
.withNewConfig(_ => conf)
@@ -86,7 +99,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModelWithOutEnforcerRules,
+ propertyModelWithoutEnforcerRules(),
ExplainImpl,
RasRule.Factory.none())
.withNewConfig(_ => conf)
@@ -103,7 +116,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.none())
.withNewConfig(_ => conf)
@@ -121,7 +134,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(ReplaceByTypeARule, ReplaceByTypeBRule)))
.withNewConfig(_ => conf)
@@ -177,7 +190,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModelWithOutEnforcerRules,
+ propertyModelWithoutEnforcerRules(),
ExplainImpl,
RasRule.Factory.reuse(List(ReplaceLeafAByLeafBRule, HitCacheOp,
FinalOp))
)
@@ -223,7 +236,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModelWithOutEnforcerRules,
+ propertyModelWithoutEnforcerRules(),
ExplainImpl,
RasRule.Factory.reuse(List(ReplaceLeafAByLeafBRule,
ReplaceUnaryBByUnaryARule))
)
@@ -252,7 +265,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(ConvertEnforcerAndTypeAToTypeB)))
.withNewConfig(_ => conf)
@@ -290,7 +303,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(ReplaceByTypeARule,
ReplaceNonUnaryByTypeBRule))
)
@@ -342,9 +355,10 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
- RasRule.Factory.reuse(List(ReduceTypeBCost, ConvertUnaryTypeBToTypeC)))
+ RasRule.Factory.reuse(List(ReduceTypeBCost, ConvertUnaryTypeBToTypeC))
+ )
.withNewConfig(_ => conf)
val plan =
@@ -394,7 +408,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(LeftOp, RightOp))
)
@@ -441,7 +455,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(ConvertTypeBEnforcerAndLeafToTypeC,
ConvertTypeATypeCToTypeC))
)
@@ -494,7 +508,7 @@ abstract class PropertySuite extends AnyFunSuite {
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
- NodeTypePropertyModel,
+ propertyModel(zeroDepth),
ExplainImpl,
RasRule.Factory.reuse(List(ReplaceNonUnaryByTypeBRule,
ReduceTypeBCost))
)
@@ -651,6 +665,23 @@ object PropertySuite {
override def shape(): Shape[TestNode] = Shapes.fixedHeight(1)
}
+ case class ZeroDepthNodeTypeEnforcerRule(reqType: NodeType) extends
RasRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ node match {
+ case group: Group =>
+ val groupType = group.propSet.get(NodeTypeDef)
+ if (groupType.satisfies(reqType)) {
+ List(group)
+ } else {
+ List(TypeEnforcer(reqType, 1, group))
+ }
+ case _ => throw new IllegalStateException()
+ }
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(0)
+ }
+
object ReplaceByTypeARule extends RasRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = {
node match {
@@ -739,24 +770,53 @@ object PropertySuite {
}
}
- object NodeTypePropertyModel extends PropertyModel[TestNode] {
- override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
- NodeTypeDef)
-
- override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _
<: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = {
- (constraint: Property[TestNode]) =>
- {
- List(NodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
- }
+ private def propertyModel(zeroDepth: Boolean): PropertyModel[TestNode] = {
+ if (zeroDepth) {
+ return PropertyModels.NodeTypePropertyModelZeroDepth
}
+ PropertyModels.NodeTypePropertyModel
}
- object NodeTypePropertyModelWithOutEnforcerRules extends
PropertyModel[TestNode] {
- override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
- NodeTypeDef)
+ private def propertyModelWithoutEnforcerRules(): PropertyModel[TestNode] = {
+ PropertyModels.NodeTypePropertyModelWithoutEnforcerRules
+ }
- override def newEnforcerRuleFactory(propertyDef: PropertyDef[TestNode, _
<: Property[TestNode]])
- : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) => List.empty
+ private object PropertyModels {
+ object NodeTypePropertyModel extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
+ NodeTypeDef)
+
+ override def newEnforcerRuleFactory(
+ propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = {
+ (constraint: Property[TestNode]) =>
+ {
+ List(NodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+ }
+ }
+ }
+
+ object NodeTypePropertyModelZeroDepth extends PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
+ NodeTypeDef)
+
+ override def newEnforcerRuleFactory(
+ propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = {
+ (constraint: Property[TestNode]) =>
+ {
+
List(ZeroDepthNodeTypeEnforcerRule(constraint.asInstanceOf[NodeType]))
+ }
+ }
+ }
+
+ object NodeTypePropertyModelWithoutEnforcerRules extends
PropertyModel[TestNode] {
+ override def propertyDefs: Seq[PropertyDef[TestNode, _ <:
Property[TestNode]]] = Seq(
+ NodeTypeDef)
+
+ override def newEnforcerRuleFactory(
+ propertyDef: PropertyDef[TestNode, _ <: Property[TestNode]])
+ : EnforcerRuleFactory[TestNode] = (_: Property[TestNode]) =>
List.empty
+ }
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
index f8a3d0799..b29e0c267 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.ras
import org.apache.gluten.ras.RasConfig.PlannerType
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.memo.Memo
+import org.apache.gluten.ras.path.Pattern
import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
import org.scalatest.funsuite.AnyFunSuite
@@ -198,6 +199,72 @@ abstract class RasSuite extends AnyFunSuite {
assert(optimized == Leaf(70))
}
+ test(s"Group expansion - fixed height") {
+ object AddUnary extends RasRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ assert(node.isInstanceOf[Group])
+ List(Unary(50, node))
+ }
+
+ override def shape(): Shape[TestNode] = Shapes.fixedHeight(0)
+ }
+
+ val ras =
+ Ras[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ RasRule.Factory.reuse(List(AddUnary)))
+ .withNewConfig(_ => conf)
+ val plan = Unary(60, Unary(90, Leaf(70)))
+ val planner = ras.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Unary(60, Unary(90, Leaf(70))))
+
+ val state = planner.newState().memoState()
+ val allPaths = state.collectAllPaths(Int.MaxValue)
+
+ assert(state.allClusters().size == 3)
+ assert(state.allGroups().size == 3)
+ assert(allPaths.size == 15)
+ }
+
+ test(s"Group expansion - pattern") {
+ object AddUnary extends RasRule[TestNode] {
+ override def shift(node: TestNode): Iterable[TestNode] = {
+ assert(node.isInstanceOf[Group])
+ List(Unary(50, node))
+ }
+
+ override def shape(): Shape[TestNode] =
Shapes.pattern(Pattern.ignore.build())
+ }
+
+ val ras =
+ Ras[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ RasRule.Factory.reuse(List(AddUnary)))
+ .withNewConfig(_ => conf)
+ val plan = Unary(60, Unary(90, Leaf(70)))
+ val planner = ras.newPlanner(plan)
+ val optimized = planner.plan()
+
+ assert(optimized == Unary(60, Unary(90, Leaf(70))))
+
+ val state = planner.newState().memoState()
+ val allPaths = state.collectAllPaths(Int.MaxValue)
+
+ assert(state.allClusters().size == 3)
+ assert(state.allGroups().size == 3)
+ assert(allPaths.size == 15)
+ }
+
test(s"Unary node insertion") {
object InsertUnary2 extends RasRule[TestNode] {
override def shift(node: TestNode): Iterable[TestNode] = node match {
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
index 5432a6c78..b5455d6af 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuiteBase.scala
@@ -163,7 +163,7 @@ object RasSuiteBase {
implicit class MemoLikeImplicits[T <: AnyRef](val memo: MemoLike[T]) {
def memorize(ras: Ras[T], node: T): RasGroup[T] = {
- memo.memorize(node, ras.propSetsOf(node))
+ memo.memorize(node, ras.propSetOf(node))
}
}
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
index e4c8cb30b..7bb713afe 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockMemoState.scala
@@ -48,6 +48,11 @@ case class MockMemoState[T <: AnyRef] private (
override def getCluster(key: RasClusterKey): RasCluster[T] =
clusterLookup(key)
override def getGroup(id: Int): RasGroup[T] = allGroups(id)
+
+ override def clusterDummyGroupLookup(): Map[RasClusterKey, RasGroup[T]] =
Map.empty
+
+ override def getDummyGroup(key: RasClusterKey): RasGroup[T] =
+ throw new UnsupportedOperationException()
}
object MockMemoState {
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
index eac3d5f70..cd8050e5f 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/mock/MockRasPath.scala
@@ -27,7 +27,7 @@ object MockRasPath {
def mock[T <: AnyRef](ras: Ras[T], node: T, keys: PathKeySet): RasPath[T] = {
val memo = Memo(ras)
- val g = memo.memorize(node, ras.propSetsOf(node))
+ val g = memo.memorize(node, ras.propSetOf(node))
val state = memo.newState()
val groupSupplier = state.asGroupSupplier()
assert(g.nodes(state).size == 1)
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
index 0c8880c29..b5ea3fc3c 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/PathFinderSuite.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras.path
-import org.apache.gluten.ras.{CanonicalNode, Ras}
+import org.apache.gluten.ras.{CanonicalNode, Ras, RasGroup}
import org.apache.gluten.ras.RasSuiteBase._
import org.apache.gluten.ras.mock.MockMemoState
import org.apache.gluten.ras.rule.RasRule
@@ -81,6 +81,64 @@ class PathFinderSuite extends AnyFunSuite {
Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
}
+ test("Find - from group") {
+ val ras =
+ Ras[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ RasRule.Factory.none())
+
+ val mock = MockMemoState.Builder(ras)
+ val cluster = mock.newCluster()
+ val groupA = cluster.newGroup()
+ val groupB = cluster.newGroup()
+ val groupC = cluster.newGroup()
+ val groupD = cluster.newGroup()
+ val groupE = cluster.newGroup()
+ val n1 = "n1"
+ val n2 = "n2"
+ val n3 = "n3"
+ val n4 = "n4"
+ val n5 = "n5"
+ val n6 = "n6"
+ val node1 = Binary(n1, groupB.self, groupC.self).asCanonical(ras)
+ val node2 = Unary(n2, groupD.self).asCanonical(ras)
+ val node3 = Unary(n3, groupE.self).asCanonical(ras)
+ val node4 = Leaf(n4, 1).asCanonical(ras)
+ val node5 = Leaf(n5, 1).asCanonical(ras)
+ val node6 = Leaf(n6, 1).asCanonical(ras)
+
+ groupA.add(node1)
+ groupB.add(node2)
+ groupC.add(node3)
+ groupD.add(node4)
+ groupE.add(List(node5, node6))
+
+ val state = mock.build()
+
+ def find(group: RasGroup[TestNode], depth: Int):
Iterable[RasPath[TestNode]] = {
+ val finder = PathFinder.builder(ras, state).depth(depth).build()
+ finder.find(group.asGroup(ras))
+ }
+
+ val height0 = find(groupA, 0).map(_.plan()).toSeq
+ val height1 = find(groupA, 1).map(_.plan()).toSeq
+ val height2 = find(groupA, 2).map(_.plan()).toSeq
+ val heightInf = find(groupA, RasPath.INF_DEPTH).map(_.plan()).toSeq
+
+ assert(height0 == List(Group(0)))
+ assert(height1 == List(Binary(n1, Group(1), Group(2))))
+ assert(height2 == List(Binary(n1, Unary(n2, Group(3)), Unary(n3,
Group(4)))))
+ assert(
+ heightInf == List(
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n5, 1))),
+ Binary(n1, Unary(n2, Leaf(n4, 1)), Unary(n3, Leaf(n6, 1)))))
+
+ }
+
test("Find - multiple depths") {
val ras =
Ras[TestNode](
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
index c7505ef74..523a22689 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
@@ -178,8 +178,22 @@ class WizardSuite extends AnyFunSuite {
finder.find(node1).map(_.plan()).toSeq
}
+ def findWithPatternsFromGroup(patterns: Seq[Pattern[TestNode]]):
Seq[TestNode] = {
+ val builder = PathFinder.builder(ras, state)
+ val finder = patterns
+ .foldLeft(builder) {
+ case (builder, pattern) =>
+ builder.output(OutputWizards.withPattern(pattern))
+ }
+ .build()
+ finder.find(groupA.asGroup(ras)).map(_.plan()).toSeq
+ }
+
+ assert(findWithPatternsFromGroup(List(Pattern.ignore[TestNode].build()))
== List(Group(0)))
+
assert(
findWithPatterns(List(Pattern.any[TestNode].build())) == List(Binary(n1,
Group(1), Group(2))))
+
assert(
findWithPatterns(
List(
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
index de71cba5b..2aefc54e9 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala
@@ -242,6 +242,7 @@ object DistributedSuite {
case object NoneDistribution extends Distribution {
override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case AnyDistribution => true
case _: Distribution => false
case _ => throw new UnsupportedOperationException()
}
@@ -296,6 +297,7 @@ object DistributedSuite {
case object NoneOrdering extends Ordering {
override def satisfies(other: Property[TestNode]): Boolean = other match {
+ case AnyOrdering => true
case _: Ordering => false
case _ => throw new UnsupportedOperationException()
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]