This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 9755cdb6b [VL] Fix shuffle error when null type is used (#5961)
9755cdb6b is described below
commit 9755cdb6b9756751b0f8a2c0ea519837c01e8def
Author: Xiduo You <[email protected]>
AuthorDate: Tue Jun 4 22:51:43 2024 +0800
[VL] Fix shuffle error when null type is used (#5961)
---
.../org/apache/gluten/execution/TestOperator.scala | 72 +++++++++++++++-------
cpp/core/shuffle/Payload.cc | 2 +
cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc | 8 +++
3 files changed, 61 insertions(+), 21 deletions(-)
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
index 905d30055..bc51ee7cb 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala
@@ -1682,26 +1682,56 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
}
}
- test("Fix shuffle with round robin partitioning fail") {
- def checkNullTypeRepartition(df: => DataFrame, numProject: Int): Unit = {
- var expected: Array[Row] = null
- withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
- expected = df.collect()
- }
- val actual = df
- checkAnswer(actual, expected)
- assert(
- collect(actual.queryExecution.executedPlan) { case p: ProjectExec => p }.size == numProject
- )
- }
-
- checkNullTypeRepartition(
- spark.table("lineitem").selectExpr("l_orderkey", "null as x").repartition(),
- 0
- )
- checkNullTypeRepartition(
- spark.table("lineitem").selectExpr("null as x", "null as y").repartition(),
- 1
- )
+ test("Fix shuffle with null type failure") {
+ // single and other partitioning
+ Seq("1", "2").foreach {
+ numShufflePartitions =>
+ withSQLConf("spark.sql.shuffle.partitions" -> numShufflePartitions) {
+ def checkNullTypeRepartition(df: => DataFrame, numProject: Int): Unit = {
+ var expected: Array[Row] = null
+ withSQLConf("spark.sql.execution.sortBeforeRepartition" -> "false") {
+ expected = df.collect()
+ }
+ val actual = df
+ checkAnswer(actual, expected)
+ assert(
+ collect(actual.queryExecution.executedPlan) {
+ case p: ProjectExec => p
+ }.size == numProject
+ )
+ assert(
+ collect(actual.queryExecution.executedPlan) {
+ case shuffle: ColumnarShuffleExchangeExec => shuffle
+ }.size == 1
+ )
+ }
+
+ // hash
+ checkNullTypeRepartition(
+ spark
+ .table("lineitem")
+ .selectExpr("l_orderkey", "null as x")
+ .repartition($"l_orderkey"),
+ 0
+ )
+ // range
+ checkNullTypeRepartition(
+ spark
+ .table("lineitem")
+ .selectExpr("l_orderkey", "null as x")
+ .repartitionByRange($"l_orderkey"),
+ 0
+ )
+ // round robin
+ checkNullTypeRepartition(
+ spark.table("lineitem").selectExpr("l_orderkey", "null as x").repartition(),
+ 0
+ )
+ checkNullTypeRepartition(
+ spark.table("lineitem").selectExpr("null as x", "null as y").repartition(),
+ 1
+ )
+ }
+ }
}
}
diff --git a/cpp/core/shuffle/Payload.cc b/cpp/core/shuffle/Payload.cc
index 626ed0cf0..beca3fa02 100644
--- a/cpp/core/shuffle/Payload.cc
+++ b/cpp/core/shuffle/Payload.cc
@@ -327,6 +327,8 @@ arrow::Result<std::vector<std::shared_ptr<arrow::Buffer>>> BlockPayload::deseria
case arrow::ListType::type_id: {
hasComplexDataType = true;
} break;
+ case arrow::NullType::type_id:
+ break;
default: {
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
diff --git a/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc b/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
index daff13703..741ca8ab9 100644
--- a/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
@@ -129,6 +129,14 @@ arrow::Status collectFlatVectorBufferStringView(
return arrow::Status::OK();
}
+template <>
+arrow::Status collectFlatVectorBuffer<facebook::velox::TypeKind::UNKNOWN>(
+ facebook::velox::BaseVector* vector,
+ std::vector<std::shared_ptr<arrow::Buffer>>& buffers,
+ arrow::MemoryPool* pool) {
+ return arrow::Status::OK();
+}
+
template <>
arrow::Status collectFlatVectorBuffer<facebook::velox::TypeKind::VARCHAR>(
facebook::velox::BaseVector* vector,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]