This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 90961bc90 [VL] RAS: Optimize offload rule code to gain better compatibility with rewrite rules (#5836)
90961bc90 is described below

commit 90961bc907955409e1f3b7c09af00aa3bf7abf16
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu May 23 08:45:33 2024 +0800

    [VL] RAS: Optimize offload rule code to gain better compatibility with rewrite rules (#5836)
---
 .../columnar/enumerated/EnumeratedTransform.scala  |  59 ++++++---
 .../columnar/enumerated/PushFilterToScan.scala     |   6 +-
 .../extension/columnar/enumerated/RasOffload.scala | 147 +++++++++++++++------
 .../columnar/enumerated/RasOffloadFilter.scala     |   5 +-
 ...gregate.scala => RasOffloadHashAggregate.scala} |   7 +-
 .../columnar/enumerated/RemoveFilter.scala         |   2 +-
 .../extension/columnar/transition/Transition.scala |   2 +-
 .../scala/org/apache/gluten/ras/path/Pattern.scala |  22 ++-
 .../org/apache/gluten/ras/path/WizardSuite.scala   |  21 ++-
 .../org/apache/gluten/ras/rule/PatternSuite.scala  |  37 +++---
 10 files changed, 206 insertions(+), 102 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 50f0dce13..c41c1ca2c 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,21 +16,29 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
-import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin, OffloadOthers, OffloadSingleNode}
+import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin, OffloadOthers}
 import org.apache.gluten.extension.columnar.transition.ConventionReq
 import org.apache.gluten.planner.GlutenOptimization
 import org.apache.gluten.planner.property.Conv
 import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.utils.LogLevelUtil
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.{ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.execution.python.EvalPythonExec
+import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.hive.HiveTableScanExecTransformer
 
 case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
   extends Rule[SparkPlan]
   with LogLevelUtil {
-  import EnumeratedTransform._
 
   private val rules = List(
     new PushFilterToScan(RasOffload.validator),
@@ -40,11 +48,35 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
   // TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
   //  (vanilla) scan with cheaper sub-query plan through cost model.
   private val offloadRules = List(
-    new AsRasOffload(OffloadOthers()),
-    new AsRasOffload(OffloadExchange()),
-    new AsRasOffload(OffloadJoin()),
-    RasOffloadAggregate,
-    RasOffloadFilter
+    RasOffload.from[Exchange](OffloadExchange()).toRule,
+    RasOffload.from[BaseJoinExec](OffloadJoin()).toRule,
+    RasOffloadHashAggregate.toRule,
+    RasOffloadFilter.toRule,
+    RasOffload.from[DataSourceV2ScanExecBase](OffloadOthers()).toRule,
+    RasOffload.from[DataSourceScanExec](OffloadOthers()).toRule,
+    RasOffload
+      .from(
+        (node: SparkPlan) => HiveTableScanExecTransformer.isHiveTableScan(node),
+        OffloadOthers())
+      .toRule,
+    RasOffload.from[CoalesceExec](OffloadOthers()).toRule,
+    RasOffload.from[ProjectExec](OffloadOthers()).toRule,
+    RasOffload.from[SortAggregateExec](OffloadOthers()).toRule,
+    RasOffload.from[ObjectHashAggregateExec](OffloadOthers()).toRule,
+    RasOffload.from[UnionExec](OffloadOthers()).toRule,
+    RasOffload.from[ExpandExec](OffloadOthers()).toRule,
+    RasOffload.from[WriteFilesExec](OffloadOthers()).toRule,
+    RasOffload.from[SortExec](OffloadOthers()).toRule,
+    RasOffload.from[TakeOrderedAndProjectExec](OffloadOthers()).toRule,
+    RasOffload.from[WindowExec](OffloadOthers()).toRule,
+    RasOffload
+      .from(
+        (node: SparkPlan) => SparkShimLoader.getSparkShims.isWindowGroupLimitExec(node),
+        OffloadOthers())
+      .toRule,
+    RasOffload.from[LimitExec](OffloadOthers()).toRule,
+    RasOffload.from[GenerateExec](OffloadOthers()).toRule,
+    RasOffload.from[EvalPythonExec](OffloadOthers()).toRule
   )
 
   private val optimization = GlutenOptimization(rules ++ offloadRules)
@@ -67,13 +99,4 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
   }
 }
 
-object EnumeratedTransform {
-
-  /** Accepts a [[OffloadSingleNode]] rule to convert it into a RAS offload rule. */
-  private class AsRasOffload(delegate: OffloadSingleNode) extends RasOffload {
-    override protected def offload(node: SparkPlan): SparkPlan = {
-      val out = delegate.offload(node)
-      out
-    }
-  }
-}
+object EnumeratedTransform {}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
index 388668287..611d6db0b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
@@ -50,16 +50,16 @@ class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] {
   override def shape(): Shape[SparkPlan] =
     anyOf(
       pattern(
-        node[SparkPlan](
+        branch[SparkPlan](
           clazz(classOf[FilterExec]),
           leaf(
             or(clazz(classOf[FileSourceScanExec]), clazz(classOf[BatchScanExec]))
           )
         ).build()),
       pattern(
-        node[SparkPlan](
+        branch[SparkPlan](
           clazz(classOf[FilterExec]),
-          node(
+          branch(
             clazz(classOf[ColumnarToRowTransition]),
             leaf(
               or(clazz(classOf[FileSourceScanExec]), clazz(classOf[BatchScanExec]))
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
index 5cabfa88e..6af89dc05 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -17,61 +17,39 @@
 package org.apache.gluten.extension.columnar.enumerated
 
 import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.OffloadSingleNode
 import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
 import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+import org.apache.gluten.ras.path.Pattern
+import org.apache.gluten.ras.path.Pattern.node
+import org.apache.gluten.ras.rule.{RasRule, Shape}
+import org.apache.gluten.ras.rule.Shapes.pattern
 
 import org.apache.spark.sql.execution.SparkPlan
 
-trait RasOffload extends RasRule[SparkPlan] {
-  import RasOffload._
+import scala.reflect.{classTag, ClassTag}
 
-  final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
-    // 0. If the node is already offloaded, return fast.
-    if (node.isInstanceOf[GlutenPlan]) {
-      return List.empty
-    }
+trait RasOffload {
+  def offload(plan: SparkPlan): SparkPlan
+  def typeIdentifier(): RasOffload.TypeIdentifier
+}
 
-    // 1. Rewrite the node to form that native library supports.
-    val rewritten = rewrites.foldLeft(node) {
-      case (node, rewrite) =>
-        node.transformUp {
-          case p =>
-            val out = rewrite.rewrite(p)
-            out
-        }
-    }
+object RasOffload {
+  trait TypeIdentifier {
+    def isInstance(node: SparkPlan): Boolean
+  }
 
-    // 2. Walk the rewritten tree.
-    val offloaded = rewritten.transformUp {
-      case from =>
-        // 3. Validate current node. If passed, offload it.
-        validator.validate(from) match {
-          case Validator.Passed =>
-            offload(from) match {
-              case t: GlutenPlan if !t.doValidate().isValid =>
-                // 4. If native validation fails on the offloaded node, return the
-                // original one.
-                from
-              case other =>
-                other
-            }
-          case Validator.Failed(reason) =>
-            from
-        }
+  object TypeIdentifier {
+    def of[T <: SparkPlan: ClassTag]: TypeIdentifier = {
+      val nodeClass: Class[SparkPlan] =
+        classTag[T].runtimeClass.asInstanceOf[Class[SparkPlan]]
+      new TypeIdentifier {
+        override def isInstance(node: SparkPlan): Boolean = nodeClass.isInstance(node)
+      }
     }
-
-    // 5. Return the final tree.
-    List(offloaded)
   }
 
-  protected def offload(node: SparkPlan): SparkPlan
-
-  final override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
-}
-
-object RasOffload {
-  val validator = Validators
+  val validator: Validator = Validators
     .builder()
     .fallbackByHint()
     .fallbackIfScanOnly()
@@ -82,4 +60,85 @@ object RasOffload {
     .build()
 
   private val rewrites = RewriteSingleNode.allRules()
+
+  def from[T <: SparkPlan: ClassTag](base: OffloadSingleNode): RasOffload = {
+    new RasOffload {
+      override def offload(plan: SparkPlan): SparkPlan = base.offload(plan)
+      override def typeIdentifier(): TypeIdentifier = TypeIdentifier.of[T]
+    }
+  }
+
+  def from(identifier: TypeIdentifier, base: OffloadSingleNode): RasOffload = {
+    new RasOffload {
+      override def offload(plan: SparkPlan): SparkPlan = base.offload(plan)
+      override def typeIdentifier(): TypeIdentifier = identifier
+    }
+  }
+
+  implicit class RasOffloadOps(base: RasOffload) {
+    def toRule: RasRule[SparkPlan] = {
+      new RuleImpl(base)
+    }
+  }
+
+  private class RuleImpl(base: RasOffload) extends RasRule[SparkPlan] {
+    private val typeIdentifier: TypeIdentifier = base.typeIdentifier()
+
+    final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+      // 0. If the node is already offloaded, fail fast.
+      assert(typeIdentifier.isInstance(node))
+
+      // 1. Rewrite the node to form that native library supports.
+      val rewritten = rewrites.foldLeft(node) {
+        case (node, rewrite) =>
+          node.transformUp {
+            case p =>
+              val out = rewrite.rewrite(p)
+              out
+          }
+      }
+
+      // 2. Walk the rewritten tree.
+      val offloaded = rewritten.transformUp {
+        case from if typeIdentifier.isInstance(from) =>
+          // 3. Validate current node. If passed, offload it.
+          validator.validate(from) match {
+            case Validator.Passed =>
+              val offloaded = base.offload(from)
+              offloaded match {
+                case t: GlutenPlan if !t.doValidate().isValid =>
+                  // 4. If native validation fails on the offloaded node, return the
+                  // original one.
+                  from
+                case other =>
+                  other
+              }
+            case Validator.Failed(reason) =>
+              from
+          }
+      }
+
+      // 5. If rewritten plan is not offload-able, discard it.
+      if (offloaded.fastEquals(rewritten)) {
+        return List.empty
+      }
+
+      // 6. Otherwise, return the final tree.
+      List(offloaded)
+    }
+
+    override def shape(): Shape[SparkPlan] = {
+      pattern(node[SparkPlan](new Pattern.Matcher[SparkPlan] {
+        override def apply(plan: SparkPlan): Boolean = {
+          if (plan.isInstanceOf[GlutenPlan]) {
+            return false
+          }
+          if (typeIdentifier.isInstance(plan)) {
+            return true
+          }
+          false
+        }
+      }).build())
+    }
+  }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
index 030d05d47..54ab9158b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
 object RasOffloadFilter extends RasOffload {
-  override protected def offload(node: SparkPlan): SparkPlan = node match {
+  override def offload(node: SparkPlan): SparkPlan = node match {
     case FilterExec(condition, child) =>
       val out = BackendsApiManager.getSparkPlanExecApiInstance
         .genFilterExecTransformer(condition, child)
@@ -29,4 +29,7 @@ object RasOffloadFilter extends RasOffload {
     case other =>
       other
   }
+
+  override def typeIdentifier(): RasOffload.TypeIdentifier =
+    RasOffload.TypeIdentifier.of[FilterExec]
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadHashAggregate.scala
similarity index 83%
rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadHashAggregate.scala
index e48545ae9..6c125478b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadHashAggregate.scala
@@ -21,11 +21,14 @@ import org.apache.gluten.execution.HashAggregateExecBaseTransformer
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 
-object RasOffloadAggregate extends RasOffload {
-  override protected def offload(node: SparkPlan): SparkPlan = node match {
+object RasOffloadHashAggregate extends RasOffload {
+  override def offload(node: SparkPlan): SparkPlan = node match {
     case agg: HashAggregateExec =>
       val out = HashAggregateExecBaseTransformer.from(agg)()
       out
     case other => other
   }
+
+  override def typeIdentifier(): RasOffload.TypeIdentifier =
+    RasOffload.TypeIdentifier.of[HashAggregateExec]
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
index c9f4b27bf..46b3b7f9e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
@@ -42,7 +42,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
 
   override def shape(): Shape[SparkPlan] =
     pattern(
-      node[SparkPlan](
+      branch[SparkPlan](
         clazz(classOf[FilterExecTransformerBase]),
         leaf(clazz(classOf[BasicScanExecTransformer]))
       ).build())
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 73a126f8d..3fd2839b5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -31,7 +31,7 @@ import scala.collection.mutable
 trait Transition {
   final def apply(plan: SparkPlan): SparkPlan = {
     val out = apply0(plan)
-    if (out.fastEquals(plan)) {
+    if (out eq plan) {
       assert(
         this == Transition.empty,
         "TransitionDef.empty / Transition.empty should be used when defining an empty transition.")
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index f20d05c7c..e60a94717 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -52,7 +52,7 @@ object Pattern {
     def children(count: Int): Seq[Node[T]]
   }
 
-  private case class Any[T <: AnyRef]() extends Node[Null] {
+  private case class Any private () extends Node[Null] {
     override def skip(): Boolean = false
     override def abort(node: CanonicalNode[Null]): Boolean = false
     override def matches(node: CanonicalNode[Null]): Boolean = true
@@ -60,12 +60,12 @@ object Pattern {
   }
 
   private object Any {
-    val INSTANCE: Any[Null] = Any[Null]()
+    val INSTANCE: Node[Null] = new Any()
     // Enclose default constructor.
-    private def apply[T <: AnyRef](): Any[T] = new Any()
+    private def apply(): Node[Null] = throw new UnsupportedOperationException()
   }
 
-  private case class Ignore[T <: AnyRef]() extends Node[Null] {
+  private case class Ignore private () extends Node[Null] {
     override def skip(): Boolean = true
     override def abort(node: CanonicalNode[Null]): Boolean = false
     override def matches(node: CanonicalNode[Null]): Boolean =
@@ -74,10 +74,17 @@ object Pattern {
   }
 
   private object Ignore {
-    val INSTANCE: Ignore[Null] = Ignore[Null]()
+    val INSTANCE: Node[Null] = new Ignore()
 
     // Enclose default constructor.
-    private def apply[T <: AnyRef](): Ignore[T] = new Ignore()
+    private def apply(): Node[Null] = throw new UnsupportedOperationException()
+  }
+
+  private case class Single[T <: AnyRef](matcher: Matcher[T]) extends Node[T] {
+    override def skip(): Boolean = false
+    override def abort(node: CanonicalNode[T]): Boolean = false
+    override def matches(node: CanonicalNode[T]): Boolean = matcher(node.self())
+    override def children(count: Int): Seq[Node[T]] = (0 until count).map(_ => ignore[T])
   }
 
   private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Seq[Node[T]])
@@ -93,7 +100,8 @@ object Pattern {
 
   def any[T <: AnyRef]: Node[T] = Any.INSTANCE.asInstanceOf[Node[T]]
   def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
-  def node[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
+  def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
+  def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
     Branch(matcher, children.toSeq)
   def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher, List.empty)
 
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
index 523a22689..59cc44600 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
@@ -198,18 +198,18 @@ class WizardSuite extends AnyFunSuite {
       findWithPatterns(
         List(
           Pattern
-            .node[TestNode](
+            .branch[TestNode](
               _ => true,
-              Pattern.node(_ => true, Pattern.ignore),
-              Pattern.node(_ => true, Pattern.ignore))
+              Pattern.branch(_ => true, Pattern.ignore),
+              Pattern.branch(_ => true, Pattern.ignore))
             .build())) == List(Binary(n1, Unary(n2, Group(3)), Unary(n3, Group(4)))))
 
     // Pattern pruning should emit all results
     val pattern1 = Pattern
-      .node[TestNode](_ => true, Pattern.node(_ => true, Pattern.ignore), Pattern.ignore)
+      .branch[TestNode](_ => true, Pattern.branch(_ => true, Pattern.ignore), Pattern.ignore)
       .build()
     val pattern2 = Pattern
-      .node[TestNode](_ => true, Pattern.ignore, Pattern.node(_ => true, Pattern.ignore))
+      .branch[TestNode](_ => true, Pattern.ignore, Pattern.branch(_ => true, Pattern.ignore))
       .build()
 
     assert(
@@ -219,10 +219,10 @@ class WizardSuite extends AnyFunSuite {
 
     // Distinguish between ignore and any
     val pattern3 = Pattern
-      .node[TestNode](_ => true, Pattern.node(_ => true, Pattern.any), Pattern.ignore)
+      .branch[TestNode](_ => true, Pattern.branch(_ => true, Pattern.any), Pattern.ignore)
       .build()
     val pattern4 = Pattern
-      .node[TestNode](_ => true, Pattern.ignore, Pattern.node(_ => true, Pattern.any))
+      .branch[TestNode](_ => true, Pattern.ignore, Pattern.branch(_ => true, Pattern.any))
       .build()
 
     assert(
@@ -231,6 +231,13 @@ class WizardSuite extends AnyFunSuite {
         Binary(n1, Group(1), Unary(n3, Leaf(n6, 1))),
         Binary(n1, Unary(n2, Leaf(n4, 1)), Group(2))))
 
+    // Single
+    val pattern5 = Pattern.node[TestNode](_ => true).build()
+    assert(findWithPatterns(List(pattern5)) == List(Binary(n1, Group(1), Group(2))))
+    val pattern6 = Pattern.node[TestNode](_.isInstanceOf[Binary]).build()
+    assert(findWithPatterns(List(pattern6)) == List(Binary(n1, Group(1), Group(2))))
+    val pattern7 = Pattern.node[TestNode](_.isInstanceOf[Leaf]).build()
+    assert(findWithPatterns(List(pattern7)).isEmpty)
   }
 
   test("Prune by mask") {
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
index 2a86f164d..64b66bbaf 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
@@ -72,12 +72,12 @@ class PatternSuite extends AnyFunSuite {
     val path = MockRasPath.mock(ras, Unary("n1", Leaf("n2", 1)))
     assert(path.height() == 2)
 
-    val pattern1 = Pattern.node[TestNode](n => n.isInstanceOf[Unary], Pattern.ignore).build()
+    val pattern1 = Pattern.branch[TestNode](n => n.isInstanceOf[Unary], Pattern.ignore).build()
     assert(pattern1.matches(path, 1))
     assert(pattern1.matches(path, 2))
 
     val pattern2 =
-      Pattern.node[TestNode](n => n.asInstanceOf[Unary].name == "foo", Pattern.ignore).build()
+      Pattern.branch[TestNode](n => n.asInstanceOf[Unary].name == "foo", Pattern.ignore).build()
     assert(!pattern2.matches(path, 1))
     assert(!pattern2.matches(path, 2))
   }
@@ -98,11 +98,11 @@ class PatternSuite extends AnyFunSuite {
     assert(path.height() == 4)
 
     val pattern = Pattern
-      .node[TestNode](
+      .branch[TestNode](
         n => n.isInstanceOf[Binary],
-        Pattern.node(
+        Pattern.branch(
           n => n.isInstanceOf[Unary],
-          Pattern.node(
+          Pattern.branch(
             n => n.isInstanceOf[Unary],
             Pattern.ignore
           )
@@ -131,11 +131,11 @@ class PatternSuite extends AnyFunSuite {
     assert(path.height() == 4)
 
     val pattern1 = Pattern
-      .node[TestNode](
+      .branch[TestNode](
         n => n.isInstanceOf[Binary],
-        Pattern.node(
+        Pattern.branch(
           n => n.isInstanceOf[Unary],
-          Pattern.node(
+          Pattern.branch(
             n => n.isInstanceOf[Unary],
             Pattern.leaf(
               _.asInstanceOf[Leaf].name == "foo"
@@ -152,13 +152,13 @@ class PatternSuite extends AnyFunSuite {
     assert(!pattern1.matches(path, 4))
 
     val pattern2 = Pattern
-      .node[TestNode](
+      .branch[TestNode](
         n => n.isInstanceOf[Binary],
-        Pattern.node(
+        Pattern.branch(
           n => n.isInstanceOf[Unary],
-          Pattern.node(
+          Pattern.branch(
             n => n.isInstanceOf[Unary],
-            Pattern.node(
+            Pattern.branch(
               n => n.isInstanceOf[Unary],
               Pattern.ignore
             )
@@ -188,9 +188,9 @@ class PatternSuite extends AnyFunSuite {
     assert(path.height() == 2)
 
     val pattern1 = Pattern
-      .node[TestNode](
+      .branch[TestNode](
         Pattern.Matchers.clazz(classOf[Unary]),
-        Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+        Pattern.branch(Pattern.Matchers.clazz(classOf[Leaf])))
       .build()
     assert(pattern1.matches(path, 1))
     assert(pattern1.matches(path, 2))
@@ -202,19 +202,20 @@ class PatternSuite extends AnyFunSuite {
     assert(!pattern2.matches(path, 2))
 
     val pattern3 = Pattern
-      .node[TestNode](
+      .branch[TestNode](
         Pattern.Matchers
           .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Leaf])),
-        Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+        Pattern.branch(Pattern.Matchers.clazz(classOf[Leaf]))
+      )
       .build()
     assert(pattern3.matches(path, 1))
     assert(pattern3.matches(path, 2))
 
     val pattern4 = Pattern
-      .node[TestNode](
+      .branch[TestNode](
         Pattern.Matchers
           .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Leaf])),
-        Pattern.node(Pattern.Matchers
+        Pattern.branch(Pattern.Matchers
           .or(Pattern.Matchers.clazz(classOf[Unary]), Pattern.Matchers.clazz(classOf[Unary])))
       )
       .build()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to