This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new fa86e7682 [GLUTEN-4875][VL]Support spark sql conf 
sortBeforeRepartition to avoid stage partial retry causing result mismatch 
(#4872)
fa86e7682 is described below

commit fa86e7682636a4a03a0ea7b60275c41b8567d4f0
Author: Terry Wang <[email protected]>
AuthorDate: Mon Mar 18 09:07:27 2024 +0800

    [GLUTEN-4875][VL]Support spark sql conf sortBeforeRepartition to avoid 
stage partial retry causing result mismatch (#4872)
---
 .../backendsapi/velox/SparkPlanExecApiImpl.scala   | 22 ++++++++++++++++---
 .../io/glutenproject/execution/TestOperator.scala  | 25 ++++++++++++++++++++++
 .../org/apache/spark/sql/GlutenImplicitsTest.scala |  4 ++--
 .../GlutenReplaceHashWithSortAggSuite.scala        |  3 +++
 .../GlutenReplaceHashWithSortAggSuite.scala        |  3 +++
 5 files changed, 52 insertions(+), 5 deletions(-)

diff --git 
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
 
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
index 60aff4b29..b8da55ea9 100644
--- 
a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
+++ 
b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -36,12 +36,12 @@ import 
org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, 
CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, 
GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, 
NamedExpression, NaNvl, PosExplode, Round, StringSplit, StringTrim}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, 
Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, 
GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, 
NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
HLLAdapter}
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, 
HashPartitioning, Partitioning, RoundRobinPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{BroadcastUtils, 
ColumnarBuildSideRelation, ColumnarShuffleExchangeExec, SparkPlan, 
VeloxColumnarWriteFilesExec}
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec}
@@ -232,7 +232,23 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
           TransformHints.tagNotTransformable(shuffle, validationResult)
           shuffle.withNewChildren(newChild :: Nil)
         }
-
+      case RoundRobinPartitioning(num) if SQLConf.get.sortBeforeRepartition && 
num > 1 =>
+        val hashExpr = new Murmur3Hash(newChild.output)
+        val projectList = Seq(Alias(hashExpr, "hash_partition_key")()) ++ 
newChild.output
+        val projectTransformer = ProjectExecTransformer(projectList, newChild)
+        val sortOrder = SortOrder(projectTransformer.output.head, Ascending)
+        val sortByHashCode = SortExecTransformer(Seq(sortOrder), global = 
false, projectTransformer)
+        val dropSortColumnTransformer = 
ProjectExecTransformer(projectList.drop(1), sortByHashCode)
+        val validationResult = dropSortColumnTransformer.doValidate()
+        if (validationResult.isValid) {
+          ColumnarShuffleExchangeExec(
+            shuffle,
+            dropSortColumnTransformer,
+            dropSortColumnTransformer.output)
+        } else {
+          TransformHints.tagNotTransformable(shuffle, validationResult)
+          shuffle.withNewChildren(newChild :: Nil)
+        }
       case _ =>
         ColumnarShuffleExchangeExec(shuffle, newChild, null)
     }
diff --git 
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala 
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
index 5dbad565c..239bec57a 100644
--- 
a/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
+++ 
b/backends-velox/src/test/scala/io/glutenproject/execution/TestOperator.scala
@@ -1233,4 +1233,29 @@ class TestOperator extends 
VeloxWholeStageTransformerSuite {
       checkOperatorMatch[HashAggregateExecTransformer]
     }
   }
+
+  test("test roundrobine with sort") {
+    // scalastyle:off
+    runQueryAndCompare("SELECT /*+ REPARTITION(3) */ l_orderkey, l_partkey 
FROM lineitem") {
+      /*
+        ColumnarExchange RoundRobinPartitioning(3), REPARTITION_BY_NUM, 
[l_orderkey#16L, l_partkey#17L)
+        +- ^(2) ProjectExecTransformer [l_orderkey#16L, l_partkey#17L]
+          +- ^(2) SortExecTransformer [hash_partition_key#302 ASC NULLS 
FIRST], false, 0
+            +- ^(2) ProjectExecTransformer [hash(l_orderkey#16L, 
l_partkey#17L) AS hash_partition_key#302, l_orderkey#16L, l_partkey#17L]
+                +- ^(2) BatchScanExecTransformer[l_orderkey#16L, 
l_partkey#17L] ParquetScan DataFilters: [], Format: parquet, Location: 
InMemoryFileIndex(1 paths)[..., PartitionFilters: [], PushedFilters: [], 
ReadSchema: struct<l_orderkey:bigint,l_partkey:bigint>, PushedFilters: [] 
RuntimeFilters: []
+       */
+      checkOperatorMatch[SortExecTransformer]
+    }
+    // scalastyle:on
+
+    withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
+      runQueryAndCompare("""SELECT /*+ REPARTITION(3) */
+                           | l_orderkey, l_partkey FROM 
lineitem""".stripMargin) {
+        df =>
+          {
+            
assert(getExecutedPlan(df).count(_.isInstanceOf[SortExecTransformer]) == 0)
+          }
+      }
+    }
+  }
 }
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala
index 52a4db9e6..e4356cec8 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenImplicitsTest.scala
@@ -119,10 +119,10 @@ class GlutenImplicitsTest extends GlutenSQLTestsBaseTrait 
{
   testGluten("fallbackSummary with cached data and shuffle") {
     withAQEEnabledAndDisabled {
       val df = spark.sql("select * from t1").filter(_.getLong(0) > 
0).cache.repartition()
-      assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary())
+      assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary())
       assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary())
       df.collect()
-      assert(df.fallbackSummary().numGlutenNodes == 3, df.fallbackSummary())
+      assert(df.fallbackSummary().numGlutenNodes == 6, df.fallbackSummary())
       assert(df.fallbackSummary().numFallbackNodes == 1, df.fallbackSummary())
     }
   }
diff --git 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
index bbc267ec2..c83912d7f 100644
--- 
a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
+++ 
b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
@@ -61,6 +61,8 @@ class GlutenReplaceHashWithSortAggSuite
 
       Seq("FIRST", "COLLECT_LIST").foreach {
         aggExpr =>
+          // Because repartition modification causing the result sort order 
not same and the
+          // result not same, so we add order by key before comparing the 
result.
           val query =
             s"""
                |SELECT key, $aggExpr(key)
@@ -72,6 +74,7 @@ class GlutenReplaceHashWithSortAggSuite
                |   SORT BY key
                |)
                |GROUP BY key
+               |ORDER BY key
              """.stripMargin
           checkAggs(query, 2, 0, 2, 0)
       }
diff --git 
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
 
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
index f86509d44..f0898beb3 100644
--- 
a/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
+++ 
b/gluten-ut/spark34/src/test/scala/org/apache/spark/sql/execution/GlutenReplaceHashWithSortAggSuite.scala
@@ -60,6 +60,8 @@ class GlutenReplaceHashWithSortAggSuite
 
       Seq("FIRST", "COLLECT_LIST").foreach {
         aggExpr =>
+          // Because repartition modification causing the result sort order 
not same and the
+          // result not same, so we add order by key before comparing the 
result.
           val query =
             s"""
                |SELECT key, $aggExpr(key)
@@ -71,6 +73,7 @@ class GlutenReplaceHashWithSortAggSuite
                |   SORT BY key
                |)
                |GROUP BY key
+               |ORDER BY key
              """.stripMargin
           checkAggs(query, 2, 0, 2, 0)
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to