This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new eb653ba39 [VL] RAS: New rule RemoveSort to remove unnecessary sorts
(#6107)
eb653ba39 is described below
commit eb653ba39bc7d79dcf86aea3a05375d74244753e
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Jun 18 09:16:33 2024 +0800
[VL] RAS: New rule RemoveSort to remove unnecessary sorts (#6107)
---
.../columnar/enumerated/EnumeratedTransform.scala | 1 +
.../extension/columnar/enumerated/RemoveSort.scala | 61 ++++++++++++++++++++++
.../scala/org/apache/gluten/ras/path/Pattern.scala | 40 +++++++++++---
.../org/apache/gluten/ras/rule/PatternSuite.scala | 30 ++++++++++-
4 files changed, 124 insertions(+), 8 deletions(-)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 0b9dcc663..9a54a1014 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -43,6 +43,7 @@ case class EnumeratedTransform(session: SparkSession,
outputsColumnar: Boolean)
private val rules = List(
new PushFilterToScan(RasOffload.validator),
+ // Drops a SortExecTransformer child of a hash aggregate / shuffled hash join when the parent
+ // has no required ordering for that child.
+ RemoveSort,
RemoveFilter
)
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala
new file mode 100644
index 000000000..5b5d5e541
--- /dev/null
+++
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.enumerated
+
+import org.apache.gluten.execution.{HashAggregateExecBaseTransformer,
ShuffledHashJoinExecTransformerBase, SortExecTransformer}
+import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.ras.path.Pattern._
+import org.apache.gluten.ras.path.Pattern.Matchers._
+import org.apache.gluten.ras.rule.{RasRule, Shape}
+import org.apache.gluten.ras.rule.Shapes._
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Removes an unnecessary sort placed below an operator that doesn't require sorted input from
+ * that child.
+ *
+ * The rule matches a hash aggregate or shuffled hash join (see `appliedTypes`) that has at least
+ * one child, each child being a [[GlutenPlan]]. For every child slot whose required ordering is
+ * empty and whose child is a [[SortExecTransformer]], the sort is replaced by its own child. When
+ * no such sort exists the rule emits no alternative, since the result would be identical to the
+ * input node.
+ *
+ * TODO: Sort's removal could be made much simpler once output ordering is added as a physical
+ * property in the RAS planner.
+ */
+object RemoveSort extends RasRule[SparkPlan] {
+  // Parent operator types whose children's sorts are candidates for removal.
+  private val appliedTypes: Seq[Class[_ <: GlutenPlan]] =
+    List(classOf[HashAggregateExecBaseTransformer], classOf[ShuffledHashJoinExecTransformerBase])
+
+  override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+    assert(node.isInstanceOf[GlutenPlan], s"Expected a GlutenPlan but got: ${node.getClass}")
+    val requiredOrderings = node.requiredChildOrdering
+    // zip would silently drop trailing children if the sizes ever disagreed; fail loudly instead.
+    assert(
+      requiredOrderings.size == node.children.size,
+      s"requiredChildOrdering size ${requiredOrderings.size} doesn't match children count " +
+        s"${node.children.size} for ${node.nodeName}")
+    val newChildren = requiredOrderings.zip(node.children).map {
+      case (Nil, sort: SortExecTransformer) =>
+        // Parent doesn't ask for sorted input from this child but a sort op was somehow added.
+        // Remove it.
+        sort.child
+      case (req, child) =>
+        // Either the parent asks for sorted input from this child, or the child is not a sort.
+        // Keep the child, but verify its ordering satisfies the requirement.
+        assert(SortOrder.orderingSatisfies(child.outputOrdering, req))
+        child
+    }
+    // Only emit an alternative when at least one sort was actually removed; otherwise the
+    // rewritten node would just duplicate the input.
+    if (newChildren.corresponds(node.children)(_ eq _)) {
+      List.empty
+    } else {
+      List(node.withNewChildren(newChildren))
+    }
+  }
+
+  // Matches an applied-type operator with at least one child, each child being a GlutenPlan.
+  override def shape(): Shape[SparkPlan] = pattern(
+    branch2[SparkPlan](
+      or(appliedTypes.map(clazz[SparkPlan](_)): _*),
+      _ >= 1,
+      _ => node(clazz(classOf[GlutenPlan]))
+    ).build()
+  )
+}
diff --git
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index e60a94717..f54b031b0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -87,14 +87,35 @@ object Pattern {
override def children(count: Int): Seq[Node[T]] = (0 until count).map(_ =>
ignore[T])
}
- private case class Branch[T <: AnyRef](matcher: Matcher[T], children:
Seq[Node[T]])
+ /**
+ * A pattern node that matches a node via `matcher` and whose child patterns come from a
+ * [[Branch.ChildrenFactory]]. Matching aborts when the factory does not accept the node's
+ * children count.
+ */
+ private case class Branch[T <: AnyRef](matcher: Matcher[T], children:
Branch.ChildrenFactory[T])
extends Node[T] {
override def skip(): Boolean = false
- override def abort(node: CanonicalNode[T]): Boolean = node.childrenCount
!= children.size
+ // Abort as soon as the children count is not one the factory accepts.
+ override def abort(node: CanonicalNode[T]): Boolean =
+ !children.acceptsChildrenCount(node.childrenCount)
override def matches(node: CanonicalNode[T]): Boolean =
matcher(node.self())
override def children(count: Int): Seq[Node[T]] = {
- assert(count == children.size)
- children
+ assert(children.acceptsChildrenCount(count))
+ // One child pattern per child index, built by the factory.
+ (0 until count).map(children.child)
+ }
+ }
+
+ private object Branch {
+ // Describes a Branch's child patterns: which children counts are accepted, and which pattern
+ // applies at each child index.
+ trait ChildrenFactory[T <: AnyRef] {
+ def child(index: Int): Node[T]
+ def acceptsChildrenCount(count: Int): Boolean
+ }
+
+ object ChildrenFactory {
+ // Fixed arity: accepts exactly `nodes.size` children; child(i) is nodes(i).
+ case class Plain[T <: AnyRef](nodes: Seq[Node[T]]) extends
ChildrenFactory[T] {
+ override def child(index: Int): Node[T] = nodes(index)
+ override def acceptsChildrenCount(count: Int): Boolean = nodes.size ==
count
+ }
+
+ // Variable arity: `arity` decides which children counts are accepted; `func` builds the
+ // pattern for a given child index.
+ case class Func[T <: AnyRef](arity: Int => Boolean, func: Int => Node[T])
+ extends ChildrenFactory[T] {
+ override def child(index: Int): Node[T] = func(index)
+ override def acceptsChildrenCount(count: Int): Boolean = arity(count)
+ }
}
}
@@ -102,8 +123,15 @@ object Pattern {
def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
- Branch(matcher, children.toSeq)
- def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher,
List.empty)
+ Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq))
+ /**
+ * Similar to #branch, but for nodes whose number of children is not fixed up front.
+ *
+ * @param matcher matcher applied to the branch node itself
+ * @param arity predicate on the node's children count; matching aborts when it returns false
+ * @param children builds the pattern for the child at the given index
+ */
+ def branch2[T <: AnyRef](
+ matcher: Matcher[T],
+ arity: Int => Boolean,
+ children: Int => Node[T]): Node[T] =
+ Branch(matcher, Branch.ChildrenFactory.Func(arity, children))
+ // A branch that accepts exactly zero children.
+ def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] =
+ Branch(matcher, Branch.ChildrenFactory.Plain(List.empty))
implicit class NodeImplicits[T <: AnyRef](node: Node[T]) {
def build(): Pattern[T] = {
diff --git
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
index 64b66bbaf..dc7f5e883 100644
---
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
+++
b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
@@ -59,6 +59,29 @@ class PatternSuite extends AnyFunSuite {
assert(pattern.matches(path, 1))
}
+  test("Match branch") {
+    val ras = Ras[TestNode](
+      PlanModelImpl,
+      CostModelImpl,
+      MetadataModelImpl,
+      PropertyModelImpl,
+      ExplainImpl,
+      RasRule.Factory.none())
+
+    val noChildren = MockRasPath.mock(ras, Branch("n1", List()))
+    val oneChild = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1))))
+    val twoChildren = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1), Leaf("n3", 1))))
+
+    // Matches any Branch node that has at least one child, whatever those children are.
+    val atLeastOneChild = Pattern
+      .branch2[TestNode](candidate => candidate.isInstanceOf[Branch], count => count >= 1, _ =>
+        Pattern.any)
+      .build()
+
+    // A childless branch is rejected by the arity predicate.
+    assert(!atLeastOneChild.matches(noChildren, 1))
+    // Branches with one or more children match at every inspected depth.
+    for (path <- Seq(oneChild, twoChildren); depth <- Seq(1, 2)) {
+      assert(atLeastOneChild.matches(path, depth))
+    }
+  }
+
test("Match unary") {
val ras =
Ras[TestNode](
@@ -231,17 +254,20 @@ object PatternSuite {
+ // Test node with exactly one child.
case class Unary(name: String, child: TestNode) extends UnaryLike {
override def selfCost(): Long = 1
-
override def withNewChildren(child: TestNode): UnaryLike = copy(child =
child)
}
+ // Test node with exactly two children.
case class Binary(name: String, left: TestNode, right: TestNode) extends
BinaryLike {
override def selfCost(): Long = 1
-
override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
copy(left = left, right = right)
}
+ // Test node with an arbitrary number of children, used by the "Match branch" test.
+ case class Branch(name: String, children: Seq[TestNode]) extends TestNode {
+ override def selfCost(): Long = 1
+ override def withNewChildren(children: Seq[TestNode]): TestNode =
copy(children = children)
+ }
+
case class DummyGroup() extends LeafLike {
override def makeCopy(): LeafLike = throw new
UnsupportedOperationException()
override def selfCost(): Long = throw new UnsupportedOperationException()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]