This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 449ac3259 [CORE][VL] RAS: Pattern matching by node classes
449ac3259 is described below

commit 449ac32590ece0b1b1d451467906a4a44c911f44
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Apr 11 12:02:19 2024 +0800

    [CORE][VL] RAS: Pattern matching by node classes
---
 .../org/apache/gluten/ras/dp/DpClusterAlgo.scala   |  2 +-
 .../scala/org/apache/gluten/ras/dp/DpPlanner.scala |  2 +-
 .../apache/gluten/ras/memo/ForwardMemoTable.scala  | 10 ++--
 .../scala/org/apache/gluten/ras/memo/Memo.scala    | 43 ++---------------
 .../org/apache/gluten/ras/memo/MemoTable.scala     | 56 +++++++++++++++-------
 .../scala/org/apache/gluten/ras/path/Pattern.scala | 15 ++++++
 .../scala/org/apache/gluten/ras/rule/Shape.scala   | 16 +++++--
 .../apache/gluten/ras/vis/GraphvizVisualizer.scala |  8 +++-
 .../org/apache/gluten/ras/rule/PatternSuite.scala  | 48 +++++++++++++++++++
 9 files changed, 133 insertions(+), 67 deletions(-)

diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
index e90ba448b..6fd95772b 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala
@@ -81,7 +81,7 @@ object DpClusterAlgo {
     }
 
     override def browseX(x: InClusterNode[T]): Iterable[RasClusterKey] = {
-      val allGroups = memoTable.allGroups()
+      val allGroups = memoTable.asGroupSupplier()
       x.can
         .getChildrenGroups(allGroups)
         .map(gn => allGroups(gn.groupId()).clusterKey())
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
index 8acf66c59..4a9e3f0f0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala
@@ -60,7 +60,7 @@ private class DpPlanner[T <: AnyRef] private (
   }
 
   private def findBest(memoTable: MemoTable[T], groupId: Int): Best[T] = {
-    val cKey = memoTable.allGroups()(groupId).clusterKey()
+    val cKey = memoTable.asGroupSupplier()(groupId).clusterKey()
     val algoDef = new DpExploreAlgoDef[T]
     val adjustment = new ExploreAdjustment(ras, memoTable, rules, enforcerRuleSet)
     DpClusterAlgo.resolve(memoTable, algoDef, adjustment, cKey)
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
index 0895544d5..dd4033866 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala
@@ -155,11 +155,13 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T])
     groupBuffer(id)
   }
 
-  override def allClusters(): Seq[RasClusterKey] = clusterKeyBuffer
+  override def allClusterKeys(): Seq[RasClusterKey] = clusterKeyBuffer
 
-  override def allGroups(): Seq[RasGroup[T]] = groupBuffer
-
-  override def allDummyGroups(): Seq[RasGroup[T]] = dummyGroupBuffer
+  override def allGroupIds(): Seq[Int] = {
+    val from = -dummyGroupBuffer.size
+    val to = groupBuffer.size
+    (from until to).toVector
+  }
 
   private def ancestorClusterIdOf(key: RasClusterKey): Int = {
     clusterDisjointSet.find(key.id())
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
index c1bb0a6bf..6406b8fb1 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala
@@ -18,7 +18,6 @@ package org.apache.gluten.ras.memo
 
 import org.apache.gluten.ras._
 import org.apache.gluten.ras.Ras.UnsafeKey
-import org.apache.gluten.ras.RasCluster.ImmutableRasCluster
 import org.apache.gluten.ras.property.PropertySet
 import org.apache.gluten.ras.vis.GraphvizVisualizer
 
@@ -78,7 +77,7 @@ object Memo {
       memoTable.getDummyGroup(clusterKey)
     }
 
-    private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = {
+    private def toCacheKey(n: T): MemoCacheKey[T] = {
       MemoCacheKey(ras, n)
     }
 
@@ -90,11 +89,11 @@ object Memo {
 
       val childrenPrepares = ras.planModel.childrenOf(n).map(child => prepareInsert(child))
 
-      val canUnsafe = ras.withNewChildren(
+      val keyUnsafe = ras.withNewChildren(
         n,
         childrenPrepares.map(childPrepare => dummyGroupOf(childPrepare.clusterKey()).self()))
 
-      val cacheKey = toCacheKeyUnsafe(canUnsafe)
+      val cacheKey = toCacheKey(keyUnsafe)
 
       val clusterKey = clusterOfUnsafe(ras.metadataModel.metadataOf(n), cacheKey)
 
@@ -144,13 +143,13 @@ object Memo {
         val childrenPrepares =
           ras.planModel.childrenOf(node).map(child => parent.prepareInsert(child))
 
-        val canUnsafe = ras.withNewChildren(
+        val keyUnsafe = ras.withNewChildren(
           node,
           childrenPrepares.map {
             childPrepare => parent.dummyGroupOf(childPrepare.clusterKey()).self()
           })
 
-        val cacheKey = parent.toCacheKeyUnsafe(canUnsafe)
+        val cacheKey = parent.toCacheKey(keyUnsafe)
 
         if (!parent.cache.contains(cacheKey)) {
           // The new node was not added to memo yet. Add it to the target cluster.
@@ -267,39 +266,7 @@ trait MemoState[T <: AnyRef] extends MemoStore[T] {
 }
 
 object MemoState {
-  def apply[T <: AnyRef](
-      ras: Ras[T],
-      clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
-      clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
-      allGroups: Seq[RasGroup[T]],
-      allDummyGroups: Seq[RasGroup[T]]): MemoState[T] = {
-    MemoStateImpl(ras, clusterLookup, clusterDummyGroupLookup, allGroups, allDummyGroups)
-  }
-
-  private case class MemoStateImpl[T <: AnyRef](
-      override val ras: Ras[T],
-      override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
-      override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
-      override val allGroups: Seq[RasGroup[T]],
-      allDummyGroups: Seq[RasGroup[T]])
-    extends MemoState[T] {
-    private val allClustersCopy = clusterLookup.values
-
-    override def getCluster(key: RasClusterKey): RasCluster[T] = clusterLookup(key)
-    override def getDummyGroup(key: RasClusterKey): RasGroup[T] = clusterDummyGroupLookup(key)
-    override def getGroup(id: Int): RasGroup[T] = {
-      if (id < 0) {
-        val out = allDummyGroups((-id - 1))
-        assert(out.id() == id)
-        return out
-      }
-      allGroups(id)
-    }
-    override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
-  }
-
   implicit class MemoStateImplicits[T <: AnyRef](state: MemoState[T]) {
-
     def formatGraphvizWithBest(best: Best[T]): String = {
       GraphvizVisualizer(state.ras(), state, best).format()
     }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
index 8e5dac02e..2e2323a1e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala
@@ -25,12 +25,8 @@ sealed trait MemoTable[T <: AnyRef] extends MemoStore[T] {
 
   def ras: Ras[T]
 
-  override def getCluster(key: RasClusterKey): RasCluster[T]
-  override def getGroup(id: Int): RasGroup[T]
-
-  def allClusters(): Seq[RasClusterKey]
-  def allGroups(): Seq[RasGroup[T]]
-  def allDummyGroups(): Seq[RasGroup[T]]
+  def allClusterKeys(): Seq[RasClusterKey]
+  def allGroupIds(): Seq[Int]
 
   def getClusterPropSets(key: RasClusterKey): Set[PropertySet[T]]
 
@@ -68,26 +64,50 @@ object MemoTable {
     }
   }
 
+  private case class MemoStateImpl[T <: AnyRef](
+      override val ras: Ras[T],
+      override val clusterLookup: Map[RasClusterKey, ImmutableRasCluster[T]],
+      override val clusterDummyGroupLookup: Map[RasClusterKey, RasGroup[T]],
+      override val allGroups: Seq[RasGroup[T]],
+      idToGroup: Map[Int, RasGroup[T]])
+    extends MemoState[T] {
+    private val allClustersCopy = clusterLookup.values
+
+    override def getCluster(key: RasClusterKey): RasCluster[T] = clusterLookup(key)
+    override def getDummyGroup(key: RasClusterKey): RasGroup[T] = clusterDummyGroupLookup(key)
+    override def getGroup(id: Int): RasGroup[T] = idToGroup(id)
+    override def allClusters(): Iterable[RasCluster[T]] = allClustersCopy
+  }
+
   implicit class MemoTableImplicits[T <: AnyRef](table: MemoTable[T]) {
     def newState(): MemoState[T] = {
       val immutableClusters = table
-        .allClusters()
+        .allClusterKeys()
         .map(key => key -> ImmutableRasCluster(table.ras, table.getCluster(key)))
         .toMap
       val immutableDummyGroups = table
-        .allClusters()
+        .allClusterKeys()
         .map(key => key -> table.getDummyGroup(key))
         .toMap
-      table.allDummyGroups().zipWithIndex.foreach {
-        case (group, idx) =>
-          assert(group.id() == -(idx + 1))
-      }
-      MemoState(
-        table.ras,
-        immutableClusters,
-        immutableDummyGroups,
-        table.allGroups(),
-        table.allDummyGroups())
+
+      var maxGroupId = Int.MinValue
+
+      val groupMap = table
+        .allGroupIds()
+        .map {
+          gid =>
+            val group = table.getGroup(gid)
+            assert(group.id() == gid)
+            if (gid > maxGroupId) {
+              maxGroupId = gid
+            }
+            gid -> group
+        }
+        .toMap
+
+      val allGroups = (0 to maxGroupId).map(table.getGroup).toVector
+
+      MemoStateImpl(table.ras, immutableClusters, immutableDummyGroups, allGroups, groupMap)
     }
 
     def doExhaustively(func: => Unit): Unit = {
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index 0471fb42a..f20d05c7c 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -27,6 +27,21 @@ trait Pattern[T <: AnyRef] {
 object Pattern {
   trait Matcher[T <: AnyRef] extends (T => Boolean)
 
+  object Matchers {
+    private case class Or[T <: AnyRef](matchers: Seq[Matcher[T]]) extends Matcher[T] {
+      override def apply(t: T): Boolean = {
+        matchers.exists(_(t))
+      }
+    }
+
+    private case class Clazz[T <: AnyRef](clazz: Class[_ <: T]) extends Matcher[T] {
+      override def apply(t: T): Boolean = clazz.isInstance(t)
+    }
+
+    def or[T <: AnyRef](matchers: Matcher[T]*): Matcher[T] = Or(matchers)
+    def clazz[T <: AnyRef](clazz: Class[_ <: T]): Matcher[T] = Clazz(clazz)
+  }
+
   trait Node[T <: AnyRef] {
     // If abort returns true, caller should make sure not to call further 
methods.
     // It provides a way to fast fail the matching before actually jumping
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
index 2a8255861..0dbc26440 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/Shape.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.ras.rule
 
-import org.apache.gluten.ras.path.{OutputWizard, OutputWizards, RasPath}
+import org.apache.gluten.ras.path.{OutputWizard, OutputWizards, PathKey, RasPath}
 
 // Shape is an abstraction for all inputs the rule can accept.
 // Shape can be specification on pattern, height, or mask
@@ -41,10 +41,15 @@ object Shapes {
     new None()
   }
 
+  def anyOf[T <: AnyRef](shapes: Shape[T]*): Shape[T] = {
+    new AnyOf[T](shapes)
+  }
+
   private class Pattern[T <: AnyRef](pattern: org.apache.gluten.ras.path.Pattern[T])
     extends Shape[T] {
-    override def wizard(): OutputWizard[T] = OutputWizards.withPattern(pattern)
-    override def identify(path: RasPath[T]): Boolean = pattern.matches(path, path.height())
+    private val key = PathKey.random()
+    override def wizard(): OutputWizard[T] = OutputWizards.withPattern(pattern).withPathKey(key)
+    override def identify(path: RasPath[T]): Boolean = path.keys().keys().contains(key)
   }
 
   private class FixedHeight[T <: AnyRef](height: Int) extends Shape[T] {
@@ -56,4 +61,9 @@ object Shapes {
     override def wizard(): OutputWizard[T] = OutputWizards.none()
     override def identify(path: RasPath[T]): Boolean = false
   }
+
+  private class AnyOf[T <: AnyRef](shapes: Seq[Shape[T]]) extends Shape[T] {
+    override def wizard(): OutputWizard[T] = OutputWizards.union(shapes.map(_.wizard()))
+    override def identify(path: RasPath[T]): Boolean = shapes.exists(_.identify(path))
+  }
 }
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
index 11f6051b0..018c8087e 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala
@@ -29,6 +29,8 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
   private val allGroups = memoState.allGroups()
   private val allClusters = memoState.clusterLookup()
 
+  private val nodeToId = mutable.Map[InGroupNode.HashKey, Int]()
+
   def format(): String = {
     val rootGroupId = best.rootGroupId()
     val bestPath = best.path()
@@ -153,12 +155,14 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best
       costs: InGroupNode[T] => Option[Cost],
       group: RasGroup[T],
       node: CanonicalNode[T]): String = {
-    s"${describeGroup(group)}[Cost ${costs(InGroupNode(group.id(), node))
+    val ign = InGroupNode(group.id(), node)
+    val nodeId = nodeToId.getOrElseUpdate(ign.toHashKey, nodeToId.size)
+    s"[$nodeId][Cost ${costs(ign)
         .map {
           case c if ras.isInfCost(c) => "<INF>"
           case other => other
         }
-        .getOrElse("N/A")}]${ras.explain.describeNode(node.self())}"
+        .getOrElse("N/A")}] ${ras.explain.describeNode(node.self())}"
   }
 }
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
index 263b1869a..2a86f164d 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
@@ -173,6 +173,54 @@ class PatternSuite extends AnyFunSuite {
     assert(pattern2.matches(path, 3))
     assert(!pattern2.matches(path, 4))
   }
+
+  test("Match class") {
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.none())
+
+    val path = MockRasPath.mock(ras, Unary("n1", Leaf("n2", 1)))
+    assert(path.height() == 2)
+
+    val pattern1 = Pattern
+      .node[TestNode](
+        Pattern.Matchers.clazz(classOf[Unary]),
+        Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+      .build()
+    assert(pattern1.matches(path, 1))
+    assert(pattern1.matches(path, 2))
+
+    val pattern2 = Pattern
+      .leaf[TestNode](Pattern.Matchers.clazz(classOf[Leaf]))
+      .build()
+    assert(!pattern2.matches(path, 1))
+    assert(!pattern2.matches(path, 2))
+
+    val pattern3 = Pattern
+      .node[TestNode](
+        Pattern.Matchers
+          .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Leaf])),
+        Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+      .build()
+    assert(pattern3.matches(path, 1))
+    assert(pattern3.matches(path, 2))
+
+    val pattern4 = Pattern
+      .node[TestNode](
+        Pattern.Matchers
+          .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Leaf])),
+        Pattern.node(Pattern.Matchers
+          .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Unary])))
+      )
+      .build()
+    assert(pattern4.matches(path, 1))
+    assert(!pattern4.matches(path, 2))
+  }
 }
 
 object PatternSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to