This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 90961bc90 [VL] RAS: Optimize offload rule code to gain better
compatibility with rewrite rules (#5836)
90961bc90 is described below
commit 90961bc907955409e1f3b7c09af00aa3bf7abf16
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu May 23 08:45:33 2024 +0800
[VL] RAS: Optimize offload rule code to gain better compatibility with
rewrite rules (#5836)
---
.../columnar/enumerated/EnumeratedTransform.scala | 59 ++++++---
.../columnar/enumerated/PushFilterToScan.scala | 6 +-
.../extension/columnar/enumerated/RasOffload.scala | 147 +++++++++++++++------
.../columnar/enumerated/RasOffloadFilter.scala | 5 +-
...gregate.scala => RasOffloadHashAggregate.scala} | 7 +-
.../columnar/enumerated/RemoveFilter.scala | 2 +-
.../extension/columnar/transition/Transition.scala | 2 +-
.../scala/org/apache/gluten/ras/path/Pattern.scala | 22 ++-
.../org/apache/gluten/ras/path/WizardSuite.scala | 21 ++-
.../org/apache/gluten/ras/rule/PatternSuite.scala | 37 +++---
10 files changed, 206 insertions(+), 102 deletions(-)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 50f0dce13..c41c1ca2c 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -16,21 +16,29 @@
*/
package org.apache.gluten.extension.columnar.enumerated
-import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin,
OffloadOthers, OffloadSingleNode}
+import org.apache.gluten.extension.columnar.{OffloadExchange, OffloadJoin,
OffloadOthers}
import org.apache.gluten.extension.columnar.transition.ConventionReq
import org.apache.gluten.planner.GlutenOptimization
import org.apache.gluten.planner.property.Conv
import org.apache.gluten.ras.property.PropertySet
+import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.LogLevelUtil
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.aggregate.{ObjectHashAggregateExec,
SortAggregateExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
+import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.joins.BaseJoinExec
+import org.apache.spark.sql.execution.python.EvalPythonExec
+import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.hive.HiveTableScanExecTransformer
case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
extends Rule[SparkPlan]
with LogLevelUtil {
- import EnumeratedTransform._
private val rules = List(
new PushFilterToScan(RasOffload.validator),
@@ -40,11 +48,35 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
// TODO: Should obey ReplaceSingleNode#applyScanNotTransformable to select
// (vanilla) scan with cheaper sub-query plan through cost model.
private val offloadRules = List(
- new AsRasOffload(OffloadOthers()),
- new AsRasOffload(OffloadExchange()),
- new AsRasOffload(OffloadJoin()),
- RasOffloadAggregate,
- RasOffloadFilter
+ RasOffload.from[Exchange](OffloadExchange()).toRule,
+ RasOffload.from[BaseJoinExec](OffloadJoin()).toRule,
+ RasOffloadHashAggregate.toRule,
+ RasOffloadFilter.toRule,
+ RasOffload.from[DataSourceV2ScanExecBase](OffloadOthers()).toRule,
+ RasOffload.from[DataSourceScanExec](OffloadOthers()).toRule,
+ RasOffload
+ .from(
+ (node: SparkPlan) =>
HiveTableScanExecTransformer.isHiveTableScan(node),
+ OffloadOthers())
+ .toRule,
+ RasOffload.from[CoalesceExec](OffloadOthers()).toRule,
+ RasOffload.from[ProjectExec](OffloadOthers()).toRule,
+ RasOffload.from[SortAggregateExec](OffloadOthers()).toRule,
+ RasOffload.from[ObjectHashAggregateExec](OffloadOthers()).toRule,
+ RasOffload.from[UnionExec](OffloadOthers()).toRule,
+ RasOffload.from[ExpandExec](OffloadOthers()).toRule,
+ RasOffload.from[WriteFilesExec](OffloadOthers()).toRule,
+ RasOffload.from[SortExec](OffloadOthers()).toRule,
+ RasOffload.from[TakeOrderedAndProjectExec](OffloadOthers()).toRule,
+ RasOffload.from[WindowExec](OffloadOthers()).toRule,
+ RasOffload
+ .from(
+ (node: SparkPlan) =>
SparkShimLoader.getSparkShims.isWindowGroupLimitExec(node),
+ OffloadOthers())
+ .toRule,
+ RasOffload.from[LimitExec](OffloadOthers()).toRule,
+ RasOffload.from[GenerateExec](OffloadOthers()).toRule,
+ RasOffload.from[EvalPythonExec](OffloadOthers()).toRule
)
private val optimization = GlutenOptimization(rules ++ offloadRules)
@@ -67,13 +99,4 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
}
}
-object EnumeratedTransform {
-
- /** Accepts a [[OffloadSingleNode]] rule to convert it into a RAS offload
rule. */
- private class AsRasOffload(delegate: OffloadSingleNode) extends RasOffload {
- override protected def offload(node: SparkPlan): SparkPlan = {
- val out = delegate.offload(node)
- out
- }
- }
-}
+object EnumeratedTransform {}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
index 388668287..611d6db0b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/PushFilterToScan.scala
@@ -50,16 +50,16 @@ class PushFilterToScan(validator: Validator) extends RasRule[SparkPlan] {
override def shape(): Shape[SparkPlan] =
anyOf(
pattern(
- node[SparkPlan](
+ branch[SparkPlan](
clazz(classOf[FilterExec]),
leaf(
or(clazz(classOf[FileSourceScanExec]),
clazz(classOf[BatchScanExec]))
)
).build()),
pattern(
- node[SparkPlan](
+ branch[SparkPlan](
clazz(classOf[FilterExec]),
- node(
+ branch(
clazz(classOf[ColumnarToRowTransition]),
leaf(
or(clazz(classOf[FileSourceScanExec]),
clazz(classOf[BatchScanExec]))
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
index 5cabfa88e..6af89dc05 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffload.scala
@@ -17,61 +17,39 @@
package org.apache.gluten.extension.columnar.enumerated
import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.extension.columnar.OffloadSingleNode
import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
import org.apache.gluten.extension.columnar.validator.{Validator, Validators}
-import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes}
+import org.apache.gluten.ras.path.Pattern
+import org.apache.gluten.ras.path.Pattern.node
+import org.apache.gluten.ras.rule.{RasRule, Shape}
+import org.apache.gluten.ras.rule.Shapes.pattern
import org.apache.spark.sql.execution.SparkPlan
-trait RasOffload extends RasRule[SparkPlan] {
- import RasOffload._
+import scala.reflect.{classTag, ClassTag}
- final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
- // 0. If the node is already offloaded, return fast.
- if (node.isInstanceOf[GlutenPlan]) {
- return List.empty
- }
+trait RasOffload {
+ def offload(plan: SparkPlan): SparkPlan
+ def typeIdentifier(): RasOffload.TypeIdentifier
+}
- // 1. Rewrite the node to form that native library supports.
- val rewritten = rewrites.foldLeft(node) {
- case (node, rewrite) =>
- node.transformUp {
- case p =>
- val out = rewrite.rewrite(p)
- out
- }
- }
+object RasOffload {
+ trait TypeIdentifier {
+ def isInstance(node: SparkPlan): Boolean
+ }
- // 2. Walk the rewritten tree.
- val offloaded = rewritten.transformUp {
- case from =>
- // 3. Validate current node. If passed, offload it.
- validator.validate(from) match {
- case Validator.Passed =>
- offload(from) match {
- case t: GlutenPlan if !t.doValidate().isValid =>
- // 4. If native validation fails on the offloaded node, return
the
- // original one.
- from
- case other =>
- other
- }
- case Validator.Failed(reason) =>
- from
- }
+ object TypeIdentifier {
+ def of[T <: SparkPlan: ClassTag]: TypeIdentifier = {
+ val nodeClass: Class[SparkPlan] =
+ classTag[T].runtimeClass.asInstanceOf[Class[SparkPlan]]
+ new TypeIdentifier {
+ override def isInstance(node: SparkPlan): Boolean =
nodeClass.isInstance(node)
+ }
}
-
- // 5. Return the final tree.
- List(offloaded)
}
- protected def offload(node: SparkPlan): SparkPlan
-
- final override def shape(): Shape[SparkPlan] = Shapes.fixedHeight(1)
-}
-
-object RasOffload {
- val validator = Validators
+ val validator: Validator = Validators
.builder()
.fallbackByHint()
.fallbackIfScanOnly()
@@ -82,4 +60,85 @@ object RasOffload {
.build()
private val rewrites = RewriteSingleNode.allRules()
+
+ def from[T <: SparkPlan: ClassTag](base: OffloadSingleNode): RasOffload = {
+ new RasOffload {
+ override def offload(plan: SparkPlan): SparkPlan = base.offload(plan)
+ override def typeIdentifier(): TypeIdentifier = TypeIdentifier.of[T]
+ }
+ }
+
+ def from(identifier: TypeIdentifier, base: OffloadSingleNode): RasOffload = {
+ new RasOffload {
+ override def offload(plan: SparkPlan): SparkPlan = base.offload(plan)
+ override def typeIdentifier(): TypeIdentifier = identifier
+ }
+ }
+
+ implicit class RasOffloadOps(base: RasOffload) {
+ def toRule: RasRule[SparkPlan] = {
+ new RuleImpl(base)
+ }
+ }
+
+ private class RuleImpl(base: RasOffload) extends RasRule[SparkPlan] {
+ private val typeIdentifier: TypeIdentifier = base.typeIdentifier()
+
+ final override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+ // 0. If the node is already offloaded, fail fast.
+ assert(typeIdentifier.isInstance(node))
+
+ // 1. Rewrite the node to form that native library supports.
+ val rewritten = rewrites.foldLeft(node) {
+ case (node, rewrite) =>
+ node.transformUp {
+ case p =>
+ val out = rewrite.rewrite(p)
+ out
+ }
+ }
+
+ // 2. Walk the rewritten tree.
+ val offloaded = rewritten.transformUp {
+ case from if typeIdentifier.isInstance(from) =>
+ // 3. Validate current node. If passed, offload it.
+ validator.validate(from) match {
+ case Validator.Passed =>
+ val offloaded = base.offload(from)
+ offloaded match {
+ case t: GlutenPlan if !t.doValidate().isValid =>
+ // 4. If native validation fails on the offloaded node,
return the
+ // original one.
+ from
+ case other =>
+ other
+ }
+ case Validator.Failed(reason) =>
+ from
+ }
+ }
+
+ // 5. If rewritten plan is not offload-able, discard it.
+ if (offloaded.fastEquals(rewritten)) {
+ return List.empty
+ }
+
+ // 6. Otherwise, return the final tree.
+ List(offloaded)
+ }
+
+ override def shape(): Shape[SparkPlan] = {
+ pattern(node[SparkPlan](new Pattern.Matcher[SparkPlan] {
+ override def apply(plan: SparkPlan): Boolean = {
+ if (plan.isInstanceOf[GlutenPlan]) {
+ return false
+ }
+ if (typeIdentifier.isInstance(plan)) {
+ return true
+ }
+ false
+ }
+ }).build())
+ }
+ }
}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
index 030d05d47..54ab9158b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadFilter.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
object RasOffloadFilter extends RasOffload {
- override protected def offload(node: SparkPlan): SparkPlan = node match {
+ override def offload(node: SparkPlan): SparkPlan = node match {
case FilterExec(condition, child) =>
val out = BackendsApiManager.getSparkPlanExecApiInstance
.genFilterExecTransformer(condition, child)
@@ -29,4 +29,7 @@ object RasOffloadFilter extends RasOffload {
case other =>
other
}
+
+ override def typeIdentifier(): RasOffload.TypeIdentifier =
+ RasOffload.TypeIdentifier.of[FilterExec]
}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadHashAggregate.scala
similarity index 83%
rename from gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
rename to gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadHashAggregate.scala
index e48545ae9..6c125478b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadAggregate.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RasOffloadHashAggregate.scala
@@ -21,11 +21,14 @@ import org.apache.gluten.execution.HashAggregateExecBaseTransformer
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
-object RasOffloadAggregate extends RasOffload {
- override protected def offload(node: SparkPlan): SparkPlan = node match {
+object RasOffloadHashAggregate extends RasOffload {
+ override def offload(node: SparkPlan): SparkPlan = node match {
case agg: HashAggregateExec =>
val out = HashAggregateExecBaseTransformer.from(agg)()
out
case other => other
}
+
+ override def typeIdentifier(): RasOffload.TypeIdentifier =
+ RasOffload.TypeIdentifier.of[HashAggregateExec]
}
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
index c9f4b27bf..46b3b7f9e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala
@@ -42,7 +42,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
override def shape(): Shape[SparkPlan] =
pattern(
- node[SparkPlan](
+ branch[SparkPlan](
clazz(classOf[FilterExecTransformerBase]),
leaf(clazz(classOf[BasicScanExecTransformer]))
).build())
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 73a126f8d..3fd2839b5 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -31,7 +31,7 @@ import scala.collection.mutable
trait Transition {
final def apply(plan: SparkPlan): SparkPlan = {
val out = apply0(plan)
- if (out.fastEquals(plan)) {
+ if (out eq plan) {
assert(
this == Transition.empty,
"TransitionDef.empty / Transition.empty should be used when defining
an empty transition.")
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index f20d05c7c..e60a94717 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -52,7 +52,7 @@ object Pattern {
def children(count: Int): Seq[Node[T]]
}
- private case class Any[T <: AnyRef]() extends Node[Null] {
+ private case class Any private () extends Node[Null] {
override def skip(): Boolean = false
override def abort(node: CanonicalNode[Null]): Boolean = false
override def matches(node: CanonicalNode[Null]): Boolean = true
@@ -60,12 +60,12 @@ object Pattern {
}
private object Any {
- val INSTANCE: Any[Null] = Any[Null]()
+ val INSTANCE: Node[Null] = new Any()
// Enclose default constructor.
- private def apply[T <: AnyRef](): Any[T] = new Any()
+ private def apply(): Node[Null] = throw new UnsupportedOperationException()
}
- private case class Ignore[T <: AnyRef]() extends Node[Null] {
+ private case class Ignore private () extends Node[Null] {
override def skip(): Boolean = true
override def abort(node: CanonicalNode[Null]): Boolean = false
override def matches(node: CanonicalNode[Null]): Boolean =
@@ -74,10 +74,17 @@ object Pattern {
}
private object Ignore {
- val INSTANCE: Ignore[Null] = Ignore[Null]()
+ val INSTANCE: Node[Null] = new Ignore()
// Enclose default constructor.
- private def apply[T <: AnyRef](): Ignore[T] = new Ignore()
+ private def apply(): Node[Null] = throw new UnsupportedOperationException()
+ }
+
+ private case class Single[T <: AnyRef](matcher: Matcher[T]) extends Node[T] {
+ override def skip(): Boolean = false
+ override def abort(node: CanonicalNode[T]): Boolean = false
+ override def matches(node: CanonicalNode[T]): Boolean =
matcher(node.self())
+ override def children(count: Int): Seq[Node[T]] = (0 until count).map(_ =>
ignore[T])
}
private case class Branch[T <: AnyRef](matcher: Matcher[T], children:
Seq[Node[T]])
@@ -93,7 +100,8 @@ object Pattern {
def any[T <: AnyRef]: Node[T] = Any.INSTANCE.asInstanceOf[Node[T]]
def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
- def node[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
+ def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
+ def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
Branch(matcher, children.toSeq)
def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher,
List.empty)
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
index 523a22689..59cc44600 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/WizardSuite.scala
@@ -198,18 +198,18 @@ class WizardSuite extends AnyFunSuite {
findWithPatterns(
List(
Pattern
- .node[TestNode](
+ .branch[TestNode](
_ => true,
- Pattern.node(_ => true, Pattern.ignore),
- Pattern.node(_ => true, Pattern.ignore))
+ Pattern.branch(_ => true, Pattern.ignore),
+ Pattern.branch(_ => true, Pattern.ignore))
.build())) == List(Binary(n1, Unary(n2, Group(3)), Unary(n3,
Group(4)))))
// Pattern pruning should emit all results
val pattern1 = Pattern
- .node[TestNode](_ => true, Pattern.node(_ => true, Pattern.ignore),
Pattern.ignore)
+ .branch[TestNode](_ => true, Pattern.branch(_ => true, Pattern.ignore),
Pattern.ignore)
.build()
val pattern2 = Pattern
- .node[TestNode](_ => true, Pattern.ignore, Pattern.node(_ => true,
Pattern.ignore))
+ .branch[TestNode](_ => true, Pattern.ignore, Pattern.branch(_ => true,
Pattern.ignore))
.build()
assert(
@@ -219,10 +219,10 @@ class WizardSuite extends AnyFunSuite {
// Distinguish between ignore and any
val pattern3 = Pattern
- .node[TestNode](_ => true, Pattern.node(_ => true, Pattern.any),
Pattern.ignore)
+ .branch[TestNode](_ => true, Pattern.branch(_ => true, Pattern.any),
Pattern.ignore)
.build()
val pattern4 = Pattern
- .node[TestNode](_ => true, Pattern.ignore, Pattern.node(_ => true,
Pattern.any))
+ .branch[TestNode](_ => true, Pattern.ignore, Pattern.branch(_ => true,
Pattern.any))
.build()
assert(
@@ -231,6 +231,13 @@ class WizardSuite extends AnyFunSuite {
Binary(n1, Group(1), Unary(n3, Leaf(n6, 1))),
Binary(n1, Unary(n2, Leaf(n4, 1)), Group(2))))
+ // Single
+ val pattern5 = Pattern.node[TestNode](_ => true).build()
+ assert(findWithPatterns(List(pattern5)) == List(Binary(n1, Group(1),
Group(2))))
+ val pattern6 = Pattern.node[TestNode](_.isInstanceOf[Binary]).build()
+ assert(findWithPatterns(List(pattern6)) == List(Binary(n1, Group(1),
Group(2))))
+ val pattern7 = Pattern.node[TestNode](_.isInstanceOf[Leaf]).build()
+ assert(findWithPatterns(List(pattern7)).isEmpty)
}
test("Prune by mask") {
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
index 2a86f164d..64b66bbaf 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
@@ -72,12 +72,12 @@ class PatternSuite extends AnyFunSuite {
val path = MockRasPath.mock(ras, Unary("n1", Leaf("n2", 1)))
assert(path.height() == 2)
- val pattern1 = Pattern.node[TestNode](n => n.isInstanceOf[Unary],
Pattern.ignore).build()
+ val pattern1 = Pattern.branch[TestNode](n => n.isInstanceOf[Unary],
Pattern.ignore).build()
assert(pattern1.matches(path, 1))
assert(pattern1.matches(path, 2))
val pattern2 =
- Pattern.node[TestNode](n => n.asInstanceOf[Unary].name == "foo",
Pattern.ignore).build()
+ Pattern.branch[TestNode](n => n.asInstanceOf[Unary].name == "foo",
Pattern.ignore).build()
assert(!pattern2.matches(path, 1))
assert(!pattern2.matches(path, 2))
}
@@ -98,11 +98,11 @@ class PatternSuite extends AnyFunSuite {
assert(path.height() == 4)
val pattern = Pattern
- .node[TestNode](
+ .branch[TestNode](
n => n.isInstanceOf[Binary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
Pattern.ignore
)
@@ -131,11 +131,11 @@ class PatternSuite extends AnyFunSuite {
assert(path.height() == 4)
val pattern1 = Pattern
- .node[TestNode](
+ .branch[TestNode](
n => n.isInstanceOf[Binary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
Pattern.leaf(
_.asInstanceOf[Leaf].name == "foo"
@@ -152,13 +152,13 @@ class PatternSuite extends AnyFunSuite {
assert(!pattern1.matches(path, 4))
val pattern2 = Pattern
- .node[TestNode](
+ .branch[TestNode](
n => n.isInstanceOf[Binary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
- Pattern.node(
+ Pattern.branch(
n => n.isInstanceOf[Unary],
Pattern.ignore
)
@@ -188,9 +188,9 @@ class PatternSuite extends AnyFunSuite {
assert(path.height() == 2)
val pattern1 = Pattern
- .node[TestNode](
+ .branch[TestNode](
Pattern.Matchers.clazz(classOf[Unary]),
- Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+ Pattern.branch(Pattern.Matchers.clazz(classOf[Leaf])))
.build()
assert(pattern1.matches(path, 1))
assert(pattern1.matches(path, 2))
@@ -202,19 +202,20 @@ class PatternSuite extends AnyFunSuite {
assert(!pattern2.matches(path, 2))
val pattern3 = Pattern
- .node[TestNode](
+ .branch[TestNode](
Pattern.Matchers
.or(Pattern.Matchers.clazz(classOf[Unary]),
Pattern.Matchers.clazz(classOf[Leaf])),
- Pattern.node(Pattern.Matchers.clazz(classOf[Leaf])))
+ Pattern.branch(Pattern.Matchers.clazz(classOf[Leaf]))
+ )
.build()
assert(pattern3.matches(path, 1))
assert(pattern3.matches(path, 2))
val pattern4 = Pattern
- .node[TestNode](
+ .branch[TestNode](
Pattern.Matchers
.or(Pattern.Matchers.clazz(classOf[Unary]),
Pattern.Matchers.clazz(classOf[Leaf])),
- Pattern.node(Pattern.Matchers
+ Pattern.branch(Pattern.Matchers
.or(Pattern.Matchers.clazz(classOf[Unary]),
Pattern.Matchers.clazz(classOf[Unary])))
)
.build()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]