This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 449ac3259 [CORE][VL] RAS: Pattern matching by node classes
449ac3259 is described below
commit 449ac32590ece0b1b1d451467906a4a44c911f44
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Apr 11 12:02:19 2024 +0800
[CORE][VL] RAS: Pattern matching by node classes
---
.../org/apache/gluten/ras/dp/DpClusterAlgo.scala | 2 +-
.../scala/org/apache/gluten/ras/dp/DpPlanner.scala | 2 +-
.../apache/gluten/ras/memo/ForwardMemoTable.scala | 10 ++--
.../scala/org/apache/gluten/ras/memo/Memo.scala | 43 ++---------------
.../org/apache/gluten/ras/memo/MemoTable.scala | 56 +++++++++++++++-------
.../scala/org/apache/gluten/ras/path/Pattern.scala | 15 ++++++
.../scala/org/apache/gluten/ras/rule/Shape.scala | 16 +++++--
.../apache/gluten/ras/vis/GraphvizVisualizer.scala | 8 +++-
.../org/apache/gluten/ras/rule/PatternSuite.scala | 48 +++++++++++++++++++
9 files changed, 133 insertions(+), 67 deletions(-)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
index e90ba448b..6fd95772b 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
@@ -81,7 +81,7 @@ object DpClusterAlgo {
}
override def browseX(x: InClusterNode[T]): Iterable[RasClusterKey] = {
- val allGroups = memoTable.allGroups()
+ val allGroups = memoTable.asGroupSupplier()
x.can
.getChildrenGroups(allGroups)
.map(gn => allGroups(gn.groupId()).clusterKey())
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 8acf66c59..4a9e3f0f0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -60,7 +60,7 @@ private class DpPlanner[T <: AnyRef] private (
}
private def findBest(memoTable: MemoTable[T], groupId: Int): Best[T] = {
- val cKey = memoTable.allGroups()(groupId).clusterKey()
+ val cKey = memoTable.asGroupSupplier()(groupId).clusterKey()
val algoDef = new DpExploreAlgoDef[T]
val adjustment = new ExploreAdjustment(ras, memoTable, rules,
enforcerRuleSet)
DpClusterAlgo.resolve(memoTable, algoDef, adjustment, cKey)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index 0895544d5..dd4033866 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -155,11 +155,13 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])
groupBuffer(id)
}
- override def allClusters(): Seq[RasClusterKey] = clusterKeyBuffer
+ override def allClusterKeys(): Seq[RasClusterKey] = clusterKeyBuffer
- override def allGroups(): Seq[RasGroup[T]] = groupBuffer
-
- override def allDummyGroups(): Seq[RasGroup[T]] = dummyGroupBuffer
+ override def allGroupIds(): Seq[Int] = {
+ val from = -dummyGroupBuffer.size
+ val to = groupBuffer.size
+ (from until to).toVector
+ }
private def ancestorClusterIdOf(key: RasClusterKey): Int = {
clusterDisjointSet.find(key.id())
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index c1bb0a6bf..6406b8fb1 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -18,7 +18,6 @@ package org.apache.gluten.ras.memo
import org.apache.gluten.ras._
import org.apache.gluten.ras.Ras.UnsafeKey
-import org.apache.gluten.ras.RasCluster.ImmutableRasCluster
import org.apache.gluten.ras.property.PropertySet
import org.apache.gluten.ras.vis.GraphvizVisualizer
@@ -78,7 +77,7 @@ object Memo {
memoTable.getDummyGroup(clusterKey)
}
- private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = {
+ private def toCacheKey(n: T): MemoCacheKey[T] = {
MemoCacheKey(ras, n)
}
@@ -90,11 +89,11 @@ object Memo {
val childrenPrepares = ras.planModel.childrenOf(n).map(child =>
prepareInsert(child))
- val canUnsafe = ras.withNewChildren(
+ val keyUnsafe = ras.withNewChildren(
n,
childrenPrepares.map(childPrepare =>
dummyGroupOf(childPrepare.clusterKey()).self()))
- val cacheKey = toCacheKeyUnsafe(canUnsafe)
+ val cacheKey = toCacheKey(keyUnsafe)
val clusterKey = clusterOfUnsafe(ras.metadataModel.metadataOf(n),
cacheKey)
@@ -144,13 +143,13 @@ object Memo {
val childrenPrepares =
ras.planModel.childrenOf(node).map(child =>
parent.prepareInsert(child))
- val canUnsafe = ras.withNewChildren(
+ val keyUnsafe = ras.withNewChildren(
node,
childrenPrepares.map {
childPrepare =>
parent.dummyGroupOf(childPrepare.clusterKey()).self()
})
- val cacheKey = parent.toCacheKeyUnsafe(canUnsafe)
+ val cacheKey = parent.toCacheKey(keyUnsafe)
if (!parent.cache.contains(cacheKey)) {
// The new node was not added to memo yet. Add it to the target cluster.
@@ -267,39 +266,7 @@ trait MemoState[T <: AnyRef] extends MemoStore[T] {
}
object MemoState {
- def apply[T <: AnyRef](
- ras: Ras[T],
- clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
- clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
- allGroups: Seq[RasGroup[T]],
- allDummyGroups: Seq[RasGroup[T]]): MemoState[T] = {
- MemoStateImpl(ras, clusterLookup, clusterDummyGroupLookup, allGroups, allDummyGroups)
- }
-
- private case class MemoStateImpl[T <: AnyRef](
- override val ras: Ras[T],
- override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
- override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
- override val allGroups: Seq[RasGroup[T]],
- allDummyGroups: Seq[RasGroup[T]])
- extends MemoState[T] {
- private val allClustersCopy = clusterLookup.values
-
- override def getCluster(key: RasClusterKey): RasCluster[T] = clusterLookup(key)
- override def getDummyGroup(key: RasClusterKey): RasGroup[T] = clusterDummyGroupLookup(key)
- override def getGroup(id: Int): RasGroup[T] = {
- if (id < 0) {
- val out = allDummyGroups((-id - 1))
- assert(out.id() == id)
- return out
- }
- allGroups(id)
- }
- override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
- }
-
implicit class MemoStateImplicits[T <: AnyRef](state: MemoState[T]) {
-
def formatGraphvizWithBest(best: Best[T]): String = {
GraphvizVisualizer(state.ras(), state, best).format()
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 8e5dac02e..2e2323a1e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -25,12 +25,8 @@ sealed trait MemoTable[T <: AnyRef] extends MemoStore[T] {
def ras: Ras[T]
- override def getCluster(key: RasClusterKey): RasCluster[T]
- override def getGroup(id: Int): RasGroup[T]
-
- def allClusters(): Seq[RasClusterKey]
- def allGroups(): Seq[RasGroup[T]]
- def allDummyGroups(): Seq[RasGroup[T]]
+ def allClusterKeys(): Seq[RasClusterKey]
+ def allGroupIds(): Seq[Int]
def getClusterPropSets(key: RasClusterKey): Set[PropertySet[T]]
@@ -68,26 +64,50 @@ object MemoTable {
}
}
+ private case class MemoStateImpl[T <: AnyRef](
+ override val ras: Ras[T],
+ override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
+ override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+ override val allGroups: Seq[RasGroup[T]],
+ idToGroup: Map[Int, RasGroup[T]])
+ extends MemoState[T] {
+ private val allClustersCopy = clusterLookup.values
+
+ override def getCluster(key: RasClusterKey): RasCluster[T] = clusterLookup(key)
+ override def getDummyGroup(key: RasClusterKey): RasGroup[T] = clusterDummyGroupLookup(key)
+ override def getGroup(id: Int): RasGroup[T] = idToGroup(id)
+ override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
+ }
+
implicit class MemoTableImplicits[T <: AnyRef](table: MemoTable[T]) {
def newState(): MemoState[T] = {
val immutableClusters = table
- .allClusters()
+ .allClusterKeys()
.map(key => key -> ImmutableRasCluster(table.ras,
table.getCluster(key)))
.toMap
val immutableDummyGroups = table
- .allClusters()
+ .allClusterKeys()
.map(key => key -> table.getDummyGroup(key))
.toMap
- table.allDummyGroups().zipWithIndex.foreach {
- case (group, idx) =>
- assert(group.id() == -(idx + 1))
- }
- MemoState(
- table.ras,
- immutableClusters,
- immutableDummyGroups,
- table.allGroups(),
- table.allDummyGroups())
+
+ var maxGroupId = Int.MinValue
+
+ val groupMap = table
+ .allGroupIds()
+ .map {
+ gid =>
+ val group = table.getGroup(gid)
+ assert(group.id() == gid)
+ if (gid > maxGroupId) {
+ maxGroupId = gid
+ }
+ gid -> group
+ }
+ .toMap
+
+ val allGroups = (0 to maxGroupId).map(table.getGroup).toVector
+
+ MemoStateImpl(table.ras, immutableClusters, immutableDummyGroups, allGroups, groupMap)
}
def doExhaustively(func: => Unit): Unit = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index 0471fb42a..f20d05c7c 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -27,6 +27,21 @@ trait Pattern[T <: AnyRef] {
object Pattern {
trait Matcher[T <: AnyRef] extends (T => Boolean)
+ object Matchers {
+ private case class Or[T <: AnyRef](matchers: Seq[Matcher[T]]) extends Matcher[T] {
+ override def apply(t: T): Boolean = {
+ matchers.exists(_(t))
+ }
+ }
+
+ private case class Clazz[T <: AnyRef](clazz: Class[_ <: T]) extends Matcher[T] {
+ override def apply(t: T): Boolean = clazz.isInstance(t)
+ }
+
+ def or[T <: AnyRef](matchers: Matcher[T]*): Matcher[T] = Or(matchers)
+ def clazz[T <: AnyRef](clazz: Class[_ <: T]): Matcher[T] = Clazz(clazz)
+ }
+
trait Node[T <: AnyRef] {
// If abort returns true, caller should make sure not to call further methods.
// It provides a way to fast fail the matching before actually jumping
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
index 2a8255861..0dbc26440 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.ras.rule
-import org.apache.gluten.ras.path.{OutputWizard, OutputWizards, RasPath}
+import org.apache.gluten.ras.path.{OutputWizard, OutputWizards, PathKey, RasPath}
// Shape is an abstraction for all inputs the rule can accept.
// Shape can be specification on pattern, height, or mask
@@ -41,10 +41,15 @@ object Shapes {
new None()
}
+ def anyOf[T <: AnyRef](shapes: Shape[T]*): Shape[T] = {
+ new AnyOf[T](shapes)
+ }
+
private class Pattern[T <: AnyRef](pattern:
org.apache.gluten.ras.path.Pattern[T])
extends Shape[T] {
- override def wizard(): OutputWizard[T] = OutputWizards.withPattern(pattern)
- override def identify(path: RasPath[T]): Boolean = pattern.matches(path, path.height())
+ private val key = PathKey.random()
+ override def wizard(): OutputWizard[T] = OutputWizards.withPattern(pattern).withPathKey(key)
+ override def identify(path: RasPath[T]): Boolean = path.keys().keys().contains(key)
}
private class FixedHeight[T <: AnyRef](height: Int) extends Shape[T] {
@@ -56,4 +61,9 @@ object Shapes {
override def wizard(): OutputWizard[T] = OutputWizards.none()
override def identify(path: RasPath[T]): Boolean = false
}
+
+ private class AnyOf[T <: AnyRef](shapes: Seq[Shape[T]]) extends Shape[T] {
+ override def wizard(): OutputWizard[T] = OutputWizards.union(shapes.map(_.wizard()))
+ override def identify(path: RasPath[T]): Boolean = shapes.exists(_.identify(path))
+ }
}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 11f6051b0..018c8087e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -29,6 +29,8 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
private val allGroups = memoState.allGroups()
private val allClusters = memoState.clusterLookup()
+ private val nodeToId = mutable.Map[InGroupNode.HashKey, Int]()
+
def format(): String = {
val rootGroupId = best.rootGroupId()
val bestPath = best.path()
@@ -153,12 +155,14 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
costs: InGroupNode[T] => Option[Cost],
group: RasGroup[T],
node: CanonicalNode[T]): String = {
- s"${describeGroup(group)}[Cost ${costs(InGroupNode(group.id(), node))
+ val ign = InGroupNode(group.id(), node)
+ val nodeId = nodeToId.getOrElseUpdate(ign.toHashKey, nodeToId.size)
+ s"[$nodeId][Cost ${costs(ign)
.map {
case c if ras.isInfCost(c) => "<INF>"
case other => other
}
- .getOrElse("N/A")}]${ras.explain.describeNode(node.self())}"
+ .getOrElse("N/A")}] ${ras.explain.describeNode(node.self())}"
}
}
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
index 263b1869a..2a86f164d 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
@@ -173,6 +173,54 @@ class PatternSuite extends AnyFunSuite {
assert(pattern2.matches(path, 3))
assert(!pattern2.matches(path, 4))
}
+
+ test("Match class") {
+ val ras =
+ Ras[TestNode](
+ PlanModelImpl,
+ CostModelImpl,
+ MetadataModelImpl,
+ PropertyModelImpl,
+ ExplainImpl,
+ RasRule.Factory.none())
+
+ val path = MockRasPath.mock(ras, Unary("n1", Leaf("n2", 1)))
+ assert(path.height() == 2)
+
+ val pattern1 = Pattern
+ .node[TestNode](
+ Pattern.Matchers.clazz(classOf[Unary]),
+ Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+ .build()
+ assert(pattern1.matches(path, 1))
+ assert(pattern1.matches(path, 2))
+
+ val pattern2 = Pattern
+ .leaf[TestNode](Pattern.Matchers.clazz(classOf[Leaf]))
+ .build()
+ assert(!pattern2.matches(path, 1))
+ assert(!pattern2.matches(path, 2))
+
+ val pattern3 = Pattern
+ .node[TestNode](
+ Pattern.Matchers
+ .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Leaf])),
+ Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+ .build()
+ assert(pattern3.matches(path, 1))
+ assert(pattern3.matches(path, 2))
+
+ val pattern4 = Pattern
+ .node[TestNode](
+ Pattern.Matchers
+ .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Leaf])),
+ Pattern.node(Pattern.Matchers
+ .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Unary])))
+ )
+ .build()
+ assert(pattern4.matches(path, 1))
+ assert(!pattern4.matches(path, 2))
+ }
}
object PatternSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]