This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new eb653ba39 [VL] RAS: New rule RemoveSort to remove unnecessary sorts (#6107)
eb653ba39 is described below

commit eb653ba39bc7d79dcf86aea3a05375d74244753e
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Jun 18 09:16:33 2024 +0800

    [VL] RAS: New rule RemoveSort to remove unnecessary sorts (#6107)
---
 .../columnar/enumerated/EnumeratedTransform.scala  |  1 +
 .../extension/columnar/enumerated/RemoveSort.scala | 61 ++++++++++++++++++++++
 .../scala/org/apache/gluten/ras/path/Pattern.scala | 40 +++++++++++---
 .../org/apache/gluten/ras/rule/PatternSuite.scala  | 30 ++++++++++-
 4 files changed, 124 insertions(+), 8 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
index 0b9dcc663..9a54a1014 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala
@@ -43,6 +43,7 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)
 
   private val rules = List(
     new PushFilterToScan(RasOffload.validator),
+    RemoveSort,
     RemoveFilter
   )
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala
new file mode 100644
index 000000000..5b5d5e541
--- /dev/null
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension.columnar.enumerated
+
+import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer}
+import org.apache.gluten.extension.GlutenPlan
+import org.apache.gluten.ras.path.Pattern._
+import org.apache.gluten.ras.path.Pattern.Matchers._
+import org.apache.gluten.ras.rule.{RasRule, Shape}
+import org.apache.gluten.ras.rule.Shapes._
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Removes unnecessary sort if its parent doesn't require for sorted input.
+ *
+ * TODO: Sort's removal could be made much simpler once output ordering is added as a physical
+ * property in RAS planer.
+ */
+object RemoveSort extends RasRule[SparkPlan] {
+  private val appliedTypes: Seq[Class[_ <: GlutenPlan]] =
+    List(classOf[HashAggregateExecBaseTransformer], classOf[ShuffledHashJoinExecTransformerBase])
+
+  override def shift(node: SparkPlan): Iterable[SparkPlan] = {
+    assert(node.isInstanceOf[GlutenPlan])
+    val newChildren = node.requiredChildOrdering.zip(node.children).map {
+      case (Nil, sort: SortExecTransformer) =>
+        // Parent doesn't ask for sorted input from this child but a sort op was somehow added.
+        // Remove it.
+        sort.child
+      case (req, child) =>
+        // Parent asks for sorted input from this child. Do nothing but an assertion.
+        assert(SortOrder.orderingSatisfies(child.outputOrdering, req))
+        child
+    }
+    val out = List(node.withNewChildren(newChildren))
+    out
+  }
+  override def shape(): Shape[SparkPlan] = pattern(
+    branch2[SparkPlan](
+      or(appliedTypes.map(clazz[SparkPlan](_)): _*),
+      _ >= 1,
+      _ => node(clazz(classOf[GlutenPlan]))
+    ).build()
+  )
+}
diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
index e60a94717..f54b031b0 100644
--- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
+++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala
@@ -87,14 +87,35 @@ object Pattern {
     override def children(count: Int): Seq[Node[T]] = (0 until count).map(_ => ignore[T])
   }
 
-  private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Seq[Node[T]])
+  private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Branch.ChildrenFactory[T])
     extends Node[T] {
     override def skip(): Boolean = false
-    override def abort(node: CanonicalNode[T]): Boolean = node.childrenCount != children.size
+    override def abort(node: CanonicalNode[T]): Boolean =
+      !children.acceptsChildrenCount(node.childrenCount)
     override def matches(node: CanonicalNode[T]): Boolean = matcher(node.self())
     override def children(count: Int): Seq[Node[T]] = {
-      assert(count == children.size)
-      children
+      assert(children.acceptsChildrenCount(count))
+      (0 until count).map(children.child)
+    }
+  }
+
+  private object Branch {
+    trait ChildrenFactory[T <: AnyRef] {
+      def child(index: Int): Node[T]
+      def acceptsChildrenCount(count: Int): Boolean
+    }
+
+    object ChildrenFactory {
+      case class Plain[T <: AnyRef](nodes: Seq[Node[T]]) extends ChildrenFactory[T] {
+        override def child(index: Int): Node[T] = nodes(index)
+        override def acceptsChildrenCount(count: Int): Boolean = nodes.size == count
+      }
+
+      case class Func[T <: AnyRef](arity: Int => Boolean, func: Int => Node[T])
+        extends ChildrenFactory[T] {
+        override def child(index: Int): Node[T] = func(index)
+        override def acceptsChildrenCount(count: Int): Boolean = arity(count)
+      }
     }
   }
 
@@ -102,8 +123,15 @@ object Pattern {
   def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
   def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
   def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
-    Branch(matcher, children.toSeq)
-  def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher, List.empty)
+    Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq))
+  // Similar to #branch, but with unknown arity.
+  def branch2[T <: AnyRef](
+      matcher: Matcher[T],
+      arity: Int => Boolean,
+      children: Int => Node[T]): Node[T] =
+    Branch(matcher, Branch.ChildrenFactory.Func(arity, children))
+  def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] =
+    Branch(matcher, Branch.ChildrenFactory.Plain(List.empty))
 
   implicit class NodeImplicits[T <: AnyRef](node: Node[T]) {
     def build(): Pattern[T] = {
diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
index 64b66bbaf..dc7f5e883 100644
--- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
+++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala
@@ -59,6 +59,29 @@ class PatternSuite extends AnyFunSuite {
     assert(pattern.matches(path, 1))
   }
 
+  test("Match branch") {
+    val ras =
+      Ras[TestNode](
+        PlanModelImpl,
+        CostModelImpl,
+        MetadataModelImpl,
+        PropertyModelImpl,
+        ExplainImpl,
+        RasRule.Factory.none())
+
+    val path1 = MockRasPath.mock(ras, Branch("n1", List()))
+    val path2 = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1))))
+    val path3 = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1), Leaf("n3", 1))))
+
+    val pattern =
+      Pattern.branch2[TestNode](n => n.isInstanceOf[Branch], _ >= 1, _ => Pattern.any).build()
+    assert(!pattern.matches(path1, 1))
+    assert(pattern.matches(path2, 1))
+    assert(pattern.matches(path2, 2))
+    assert(pattern.matches(path3, 1))
+    assert(pattern.matches(path3, 2))
+  }
+
   test("Match unary") {
     val ras =
       Ras[TestNode](
@@ -231,17 +254,20 @@ object PatternSuite {
 
   case class Unary(name: String, child: TestNode) extends UnaryLike {
     override def selfCost(): Long = 1
-
     override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
   }
 
   case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
     override def selfCost(): Long = 1
-
     override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
       copy(left = left, right = right)
   }
 
+  case class Branch(name: String, children: Seq[TestNode]) extends TestNode {
+    override def selfCost(): Long = 1
+    override def withNewChildren(children: Seq[TestNode]): TestNode = copy(children = children)
+  }
+
   case class DummyGroup() extends LeafLike {
     override def makeCopy(): LeafLike = throw new UnsupportedOperationException()
     override def selfCost(): Long = throw new UnsupportedOperationException()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to