This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 9755cdb6b [VL] Fix shuffle error when null type is used (#5961)
9755cdb6b is described below

commit 9755cdb6b9756751b0f8a2c0ea519837c01e8def
Author: Xiduo You <[email protected]>
AuthorDate: Tue Jun 4 22:51:43 2024 +0800

    [VL] Fix shuffle error when null type is used (#5961)
---
 .../org/apache/gluten/execution/TestOperator.scala | 72 +++++++++++++++-------
 cpp/core/shuffle/Payload.cc                        |  2 +
 cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc   |  8 +++
 3 files changed, 61 insertions(+), 21 deletions(-)

diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index 905d30055..bc51ee7cb 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -1682,26 +1682,56 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
     }
   }
 
-  test("Fix shuffle with round robin partitioning fail") {
-    def checkNullTypeRepartition(df: => DataFrame, numProject: Int): Unit = {
-      var expected: Array[Row] = null
-      withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
-        expected = df.collect()
-      }
-      val actual = df
-      checkAnswer(actual, expected)
-      assert(
-      collect(actual.queryExecution.executedPlan) { case p: ProjectExec => p }.size == numProject
-      )
-    }
-
-    checkNullTypeRepartition(
-      spark.table("lineitem").selectExpr("l_orderkey", "null as x").repartition(),
-      0
-    )
-    checkNullTypeRepartition(
-      spark.table("lineitem").selectExpr("null as x", "null as y").repartition(),
-      1
-    )
+  test("Fix shuffle with null type failure") {
+    // single and other partitioning
+    Seq("1", "2").foreach {
+      numShufflePartitions =>
+        withSQLConf("spark.sql.shuffle.partitions" -> numShufflePartitions) {
+          def checkNullTypeRepartition(df: => DataFrame, numProject: Int): Unit = {
+            var expected: Array[Row] = null
+            withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
+              expected = df.collect()
+            }
+            val actual = df
+            checkAnswer(actual, expected)
+            assert(
+              collect(actual.queryExecution.executedPlan) {
+                case p: ProjectExec => p
+              }.size == numProject
+            )
+            assert(
+              collect(actual.queryExecution.executedPlan) {
+                case shuffle: ColumnarShuffleExchangeExec => shuffle
+              }.size == 1
+            )
+          }
+
+          // hash
+          checkNullTypeRepartition(
+            spark
+              .table("lineitem")
+              .selectExpr("l_orderkey", "null as x")
+              .repartition($"l_orderkey"),
+            0
+          )
+          // range
+          checkNullTypeRepartition(
+            spark
+              .table("lineitem")
+              .selectExpr("l_orderkey", "null as x")
+              .repartitionByRange($"l_orderkey"),
+            0
+          )
+          // round robin
+          checkNullTypeRepartition(
+            spark.table("lineitem").selectExpr("l_orderkey", "null as x").repartition(),
+            0
+          )
+          checkNullTypeRepartition(
+            spark.table("lineitem").selectExpr("null as x", "null as y").repartition(),
+            1
+          )
+        }
+    }
   }
 }
diff --git a/cpp/core/shuffle/Payload.cc b/cpp/core/shuffle/Payload.cc
index 626ed0cf0..beca3fa02 100644
--- a/cpp/core/shuffle/Payload.cc
+++ b/cpp/core/shuffle/Payload.cc
@@ -327,6 +327,8 @@ arrow::Result<std::vector<std::shared_ptr<arrow::Buffer>>> BlockPayload::deseria
       case arrow::ListType::type_id: {
         hasComplexDataType = true;
       } break;
+      case arrow::NullType::type_id:
+        break;
       default: {
         buffers.emplace_back();
         ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
diff --git a/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc b/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
index daff13703..741ca8ab9 100644
--- a/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
@@ -129,6 +129,14 @@ arrow::Status collectFlatVectorBufferStringView(
   return arrow::Status::OK();
 }
 
+template <>
+arrow::Status collectFlatVectorBuffer<facebook::velox::TypeKind::UNKNOWN>(
+    facebook::velox::BaseVector* vector,
+    std::vector<std::shared_ptr<arrow::Buffer>>& buffers,
+    arrow::MemoryPool* pool) {
+  return arrow::Status::OK();
+}
+
 template <>
 arrow::Status collectFlatVectorBuffer<facebook::velox::TypeKind::VARCHAR>(
     facebook::velox::BaseVector* vector,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to