This is an automated email from the ASF dual-hosted git repository.
marin-ma pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4634c80a21 [GLUTEN-11915][VL] Support RowBasedChecksum in ColumnarShuffleWriter (SPARK-51756) (#12067)
4634c80a21 is described below
commit 4634c80a21397cde76eb791c331e396341011245
Author: Shaojie Li <[email protected]>
AuthorDate: Mon Jun 8 14:30:03 2026 -0700
[GLUTEN-11915][VL] Support RowBasedChecksum in ColumnarShuffleWriter (SPARK-51756) (#12067)
Implement order-independent row-based checksum for non-deterministic stage
retry detection.
- C++ computeRowBasedChecksums(): UnsafeRowFast + XXH64, per-partition XOR+SUM
- JNI: pass config, return checksum array
- Scala: read SQLConf (OR logic), pass to native, use for MapStatus
- Shim: GlutenMapStatusUtil for Spark 3.3-4.1 compatibility
- Tests: C++ unit (4/4) + Scala integration (3/3)
---
.../VeloxCelebornColumnarShuffleWriter.scala | 3 +-
.../writer/VeloxUniffleColumnarShuffleWriter.java | 3 +-
.../spark/shuffle/ColumnarShuffleWriter.scala | 21 +++-
cpp/core/jni/JniWrapper.cc | 16 ++-
cpp/core/shuffle/Options.h | 2 +
cpp/core/shuffle/ShuffleWriter.cc | 4 +
cpp/core/shuffle/ShuffleWriter.h | 2 +
cpp/velox/shuffle/VeloxHashShuffleWriter.cc | 65 +++++++++++
cpp/velox/shuffle/VeloxHashShuffleWriter.h | 11 +-
cpp/velox/tests/CMakeLists.txt | 1 +
cpp/velox/tests/RowBasedChecksumTest.cc | 121 +++++++++++++++++++++
.../gluten/vectorized/GlutenSplitResult.java | 9 +-
.../gluten/vectorized/ShuffleWriterJniWrapper.java | 3 +-
.../gluten/utils/velox/VeloxTestSettings.scala | 6 +-
.../sql/gluten/GlutenRowBasedChecksumSuite.scala | 91 ++++++++++++++++
.../apache/spark/shuffle/GlutenMapStatusUtil.scala | 32 ++++++
.../apache/spark/shuffle/GlutenMapStatusUtil.scala | 32 ++++++
.../apache/spark/shuffle/GlutenMapStatusUtil.scala | 32 ++++++
.../apache/spark/shuffle/GlutenMapStatusUtil.scala | 32 ++++++
.../apache/spark/shuffle/GlutenMapStatusUtil.scala | 37 +++++++
20 files changed, 509 insertions(+), 14 deletions(-)
diff --git a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
index 03208adbcf..7123525a8a 100644
--- a/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
+++ b/backends-velox/src-celeborn/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
@@ -151,7 +151,8 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
nativeBufferSize,
GlutenConfig.get.columnarShuffleReallocThreshold,
GlutenConfig.get.columnarShufflePartitionBufferEvictThreshold,
- partitionWriterHandle
+ partitionWriterHandle,
+ false
)
case SortShuffleWriterType =>
shuffleWriterJniWrapper.createSortShuffleWriter(
diff --git a/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java b/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
index e01f97ba3e..c6a11fb1f3 100644
--- a/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
+++ b/backends-velox/src-uniffle/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java
@@ -186,7 +186,8 @@ public class VeloxUniffleColumnarShuffleWriter<K, V> extends RssShuffleWriter<K,
nativeBufferSize,
reallocThreshold,
GlutenConfig.get().columnarShufflePartitionBufferEvictThreshold(),
- partitionWriterHandle);
+ partitionWriterHandle,
+ false);
}
runtime
diff --git a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index ff1c66e037..73bb2770e5 100644
--- a/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -60,6 +60,8 @@ class ColumnarShuffleWriter[K, V](
private val blockManager = SparkEnv.get.blockManager
+ private val rowBasedChecksumEnabled: Boolean =
GlutenMapStatusUtil.isRowBasedChecksumEnabled
+
// Are we in the process of stopping? Because map tasks can call stop() with
success = true
// and then call stop() with success = false if they get an exception, we
want to make sure
// we don't try deleting files, etc twice.
@@ -193,7 +195,8 @@ class ColumnarShuffleWriter[K, V](
nativeBufferSize,
reallocThreshold,
GlutenConfig.get.columnarShufflePartitionBufferEvictThreshold,
- partitionWriterHandle
+ partitionWriterHandle,
+ rowBasedChecksumEnabled
)
}
@@ -282,7 +285,15 @@ class ColumnarShuffleWriter[K, V](
// almost 3 times than vanilla spark partitionLengths
// This value is sensitive in rules such as AQE rule OptimizeSkewedJoin
DynamicJoinSelection
// May affect the final plan
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths,
mapId)
+ val rowChecksums = splitResult.getRowBasedChecksums
+ val aggregatedChecksum = if (rowChecksums != null &&
rowChecksums.nonEmpty) {
+ rowChecksums.foldLeft(0L)((acc, c) => acc * 31L + c)
+ } else 0L
+ mapStatus = GlutenMapStatusUtil.createMapStatus(
+ blockManager.shuffleServerId,
+ partitionLengths,
+ mapId,
+ aggregatedChecksum)
}
private def handleEmptyInput(): Unit = {
@@ -293,7 +304,11 @@ class ColumnarShuffleWriter[K, V](
partitionLengths,
Array[Long](),
null)
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths,
mapId)
+ mapStatus = GlutenMapStatusUtil.createMapStatus(
+ blockManager.shuffleServerId,
+ partitionLengths,
+ mapId,
+ 0L)
}
@throws[IOException]
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index b1580782c0..827c5ad8bd 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -270,7 +270,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
jniByteInputStreamClose = getMethodIdOrError(env, jniByteInputStreamClass,
"close", "()V");
splitResultClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/vectorized/GlutenSplitResult;");
- splitResultConstructor = getMethodIdOrError(env, splitResultClass, "<init>",
"(JJJJJJJJJJDJ[J[J)V");
+ splitResultConstructor = getMethodIdOrError(env, splitResultClass, "<init>",
"(JJJJJJJJJJDJ[J[J[J)V");
metricsBuilderClass = createGlobalClassReferenceOrError(env,
"Lorg/apache/gluten/metrics/Metrics;");
@@ -993,7 +993,8 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
jint splitBufferSize,
jdouble splitBufferReallocThreshold,
jint partitionBufferEvictThreshold,
- jlong partitionWriterHandle) {
+ jlong partitionWriterHandle,
+ jboolean rowBasedChecksumEnabled) {
JNI_METHOD_START
const auto ctx = getRuntime(env, wrapper);
@@ -1009,6 +1010,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrappe
splitBufferSize,
splitBufferReallocThreshold,
partitionBufferEvictThreshold);
+ shuffleWriterOptions->rowBasedChecksumEnabled = rowBasedChecksumEnabled;
return ctx->saveObject(ctx->createShuffleWriter(numPartitions,
partitionWriter, shuffleWriterOptions));
JNI_METHOD_END(kInvalidObjectHandle)
@@ -1163,6 +1165,13 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
auto rawSrc = reinterpret_cast<const jlong*>(rawPartitionLengths.data());
env->SetLongArrayRegion(rawPartitionLengthArr, 0,
rawPartitionLengths.size(), rawSrc);
+ const auto& rowBasedChecksums = shuffleWriter->rowBasedChecksums();
+ auto rowBasedChecksumArr = env->NewLongArray(rowBasedChecksums.size());
+ if (!rowBasedChecksums.empty()) {
+ auto checksumSrc = reinterpret_cast<const
jlong*>(rowBasedChecksums.data());
+ env->SetLongArrayRegion(rowBasedChecksumArr, 0, rowBasedChecksums.size(),
checksumSrc);
+ }
+
jobject splitResult = env->NewObject(
splitResultClass,
splitResultConstructor,
@@ -1179,7 +1188,8 @@ JNIEXPORT jobject JNICALL Java_org_apache_gluten_vectorized_ShuffleWriterJniWrap
shuffleWriter->avgDictionaryFields(),
shuffleWriter->dictionarySize(),
partitionLengthArr,
- rawPartitionLengthArr);
+ rawPartitionLengthArr,
+ rowBasedChecksumArr);
return splitResult;
JNI_METHOD_END(nullptr)
diff --git a/cpp/core/shuffle/Options.h b/cpp/core/shuffle/Options.h
index 649a164774..e4b24740a8 100644
--- a/cpp/core/shuffle/Options.h
+++ b/cpp/core/shuffle/Options.h
@@ -74,6 +74,7 @@ struct ShuffleWriterOptions {
ShuffleWriterType shuffleWriterType;
Partitioning partitioning = Partitioning::kRoundRobin;
int32_t startPartitionId = 0;
+ bool rowBasedChecksumEnabled = false;
ShuffleWriterOptions(ShuffleWriterType shuffleWriterType) :
shuffleWriterType(shuffleWriterType) {}
@@ -234,5 +235,6 @@ struct ShuffleWriterMetrics {
int64_t dictionarySize{0};
std::vector<int64_t> partitionLengths{};
std::vector<int64_t> rawPartitionLengths{}; // Uncompressed size.
+ std::vector<int64_t> rowBasedChecksums{}; // Per-partition row-based
checksums.
};
} // namespace gluten
diff --git a/cpp/core/shuffle/ShuffleWriter.cc b/cpp/core/shuffle/ShuffleWriter.cc
index 3f0feadfb0..7287181ffa 100644
--- a/cpp/core/shuffle/ShuffleWriter.cc
+++ b/cpp/core/shuffle/ShuffleWriter.cc
@@ -109,6 +109,10 @@ const std::vector<int64_t>& ShuffleWriter::rawPartitionLengths() const {
return metrics_.rawPartitionLengths;
}
+const std::vector<int64_t>& ShuffleWriter::rowBasedChecksums() const {
+ return metrics_.rowBasedChecksums;
+}
+
ShuffleWriter::ShuffleWriter(int32_t numPartitions, Partitioning partitioning)
: numPartitions_(numPartitions), partitioning_(partitioning) {}
} // namespace gluten
diff --git a/cpp/core/shuffle/ShuffleWriter.h b/cpp/core/shuffle/ShuffleWriter.h
index 934ad09076..2a892e2b32 100644
--- a/cpp/core/shuffle/ShuffleWriter.h
+++ b/cpp/core/shuffle/ShuffleWriter.h
@@ -67,6 +67,8 @@ class ShuffleWriter : public Reclaimable {
const std::vector<int64_t>& rawPartitionLengths() const;
+ const std::vector<int64_t>& rowBasedChecksums() const;
+
protected:
ShuffleWriter(int32_t numPartitions, Partitioning partitioning);
diff --git a/cpp/velox/shuffle/VeloxHashShuffleWriter.cc b/cpp/velox/shuffle/VeloxHashShuffleWriter.cc
index dfb799806a..0126b2846e 100644
--- a/cpp/velox/shuffle/VeloxHashShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxHashShuffleWriter.cc
@@ -25,6 +25,8 @@
#include "utils/VeloxArrowUtils.h"
#include "velox/buffer/Buffer.h"
#include "velox/common/base/Nulls.h"
+#include "velox/external/xxhash/xxhash.h"
+#include "velox/row/UnsafeRowFast.h"
#include "velox/type/HugeInt.h"
#include "velox/type/Timestamp.h"
#include "velox/type/Type.h"
@@ -182,6 +184,11 @@ arrow::Status VeloxHashShuffleWriter::init() {
partitionBufferBase_.resize(numPartitions_);
+ if (rowBasedChecksumEnabled_) {
+ checksumXor_.resize(numPartitions_, 0);
+ checksumSum_.resize(numPartitions_, 0);
+ }
+
return arrow::Status::OK();
}
@@ -362,6 +369,17 @@ arrow::Status VeloxHashShuffleWriter::stop() {
stat();
+ // Populate row-based checksums into metrics.
+ if (rowBasedChecksumEnabled_) {
+ metrics_.rowBasedChecksums.resize(numPartitions_);
+ for (auto pid = 0; pid < numPartitions_; ++pid) {
+ int64_t xorVal = checksumXor_[pid];
+ int64_t sumVal = checksumSum_[pid];
+ int64_t rotated = (static_cast<uint64_t>(sumVal) << 27) |
(static_cast<uint64_t>(sumVal) >> 37);
+ metrics_.rowBasedChecksums[pid] = xorVal ^ rotated;
+ }
+ }
+
return arrow::Status::OK();
}
@@ -423,6 +441,7 @@ void VeloxHashShuffleWriter::setSplitState(SplitState state) {
arrow::Status VeloxHashShuffleWriter::doSplit(const
facebook::velox::RowVector& rv, int64_t memLimit) {
auto rowNum = rv.size();
RETURN_NOT_OK(buildPartition2Row(rowNum));
+ computeRowBasedChecksums(rv);
RETURN_NOT_OK(updateInputHasNull(rv));
{
@@ -1617,4 +1636,50 @@ bool VeloxHashShuffleWriter::isExtremelyLargeBatch(facebook::velox::RowVectorPtr
return (rv->size() > maxBatchSize_ && maxBatchSize_ > 0);
}
+void VeloxHashShuffleWriter::computeRowBasedChecksums(const
facebook::velox::RowVector& rv) {
+ if (!rowBasedChecksumEnabled_) {
+ return;
+ }
+
+ auto numRows = rv.size();
+ VELOX_DCHECK(rv.nulls() == nullptr, "RowVector with top-level nulls not
supported for checksum");
+ // Get the RowVector to serialize (strip pid column if present).
+ facebook::velox::RowVectorPtr dataVector;
+ if (partitioner_->hasPid()) {
+ // Strip the first column (partition id).
+ auto rowType = std::dynamic_pointer_cast<const
facebook::velox::RowType>(rv.type());
+ std::vector<std::string> names(rowType->names().begin() + 1,
rowType->names().end());
+ std::vector<facebook::velox::TypePtr> types(rowType->children().begin() +
1, rowType->children().end());
+ std::vector<facebook::velox::VectorPtr> children(rv.children().begin() +
1, rv.children().end());
+ auto dataType = facebook::velox::ROW(std::move(names), std::move(types));
+ dataVector =
+ std::make_shared<facebook::velox::RowVector>(rv.pool(), dataType,
nullptr, numRows, std::move(children));
+ } else {
+ auto rowType = std::dynamic_pointer_cast<const
facebook::velox::RowType>(rv.type());
+ dataVector = std::make_shared<facebook::velox::RowVector>(rv.pool(),
rowType, nullptr, numRows, rv.children());
+ }
+
+ facebook::velox::row::UnsafeRowFast fast(dataVector);
+ auto dataType = std::dynamic_pointer_cast<const
facebook::velox::RowType>(dataVector->type());
+ auto fixedSize = facebook::velox::row::UnsafeRowFast::fixedRowSize(dataType);
+ int32_t bufSize = fixedSize.value_or(1024);
+ if (checksumBuffer_.size() < static_cast<size_t>(bufSize)) {
+ checksumBuffer_.resize(bufSize);
+ }
+
+ for (uint32_t row = 0; row < numRows; ++row) {
+ auto pid = row2Partition_[row];
+ auto size = fast.rowSize(row);
+ if (size > static_cast<int32_t>(checksumBuffer_.size())) {
+ checksumBuffer_.resize(size);
+ }
+ std::memset(checksumBuffer_.data(), 0, size);
+ fast.serialize(row, checksumBuffer_.data());
+
+ auto hash = static_cast<int64_t>(XXH64(checksumBuffer_.data(), size, 0));
+ checksumXor_[pid] ^= hash;
+ checksumSum_[pid] += hash;
+ }
+}
+
} // namespace gluten
diff --git a/cpp/velox/shuffle/VeloxHashShuffleWriter.h b/cpp/velox/shuffle/VeloxHashShuffleWriter.h
index d2901019b7..ea7a659f2e 100644
--- a/cpp/velox/shuffle/VeloxHashShuffleWriter.h
+++ b/cpp/velox/shuffle/VeloxHashShuffleWriter.h
@@ -279,7 +279,8 @@ class VeloxHashShuffleWriter : public VeloxShuffleWriter {
: VeloxShuffleWriter(numPartitions, partitionWriter, options,
memoryManager),
splitBufferSize_(options->splitBufferSize),
splitBufferReallocThreshold_(options->splitBufferReallocThreshold),
- partitionBufferEvictThreshold_(options->partitionBufferEvictThreshold)
{
+ partitionBufferEvictThreshold_(options->partitionBufferEvictThreshold),
+ rowBasedChecksumEnabled_(options->rowBasedChecksumEnabled) {
arenas_.resize(numPartitions);
}
@@ -516,6 +517,14 @@ class VeloxHashShuffleWriter : public VeloxShuffleWriter {
// See inputEncodingSkippedBatches() above.
int64_t inputEncodingSkippedBatches_{0};
+
+ // Row-based checksum state (per-partition XOR + SUM aggregation).
+ bool rowBasedChecksumEnabled_{false};
+ std::vector<int64_t> checksumXor_;
+ std::vector<int64_t> checksumSum_;
+ std::vector<char> checksumBuffer_;
+
+ void computeRowBasedChecksums(const facebook::velox::RowVector& rv);
}; // class VeloxHashBasedShuffleWriter
} // namespace gluten
diff --git a/cpp/velox/tests/CMakeLists.txt b/cpp/velox/tests/CMakeLists.txt
index ebcac56bc5..3052ce23f6 100644
--- a/cpp/velox/tests/CMakeLists.txt
+++ b/cpp/velox/tests/CMakeLists.txt
@@ -141,6 +141,7 @@ if(ENABLE_S3)
add_velox_test(gluten_s3_file_system_test SOURCES GlutenS3FileSystemTest.cc)
endif()
add_velox_test(scoped_timer_test SOURCES ScopedTimerTest.cc)
+add_velox_test(row_based_checksum_test SOURCES RowBasedChecksumTest.cc)
if(BUILD_EXAMPLES)
add_velox_test(my_udf_test SOURCES MyUdfTest.cc)
endif()
diff --git a/cpp/velox/tests/RowBasedChecksumTest.cc b/cpp/velox/tests/RowBasedChecksumTest.cc
new file mode 100644
index 0000000000..8f640dfdd9
--- /dev/null
+++ b/cpp/velox/tests/RowBasedChecksumTest.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "velox/common/memory/Memory.h"
+#include "velox/external/xxhash/xxhash.h"
+#include "velox/row/UnsafeRowFast.h"
+#include "velox/type/Type.h"
+#include "velox/vector/FlatVector.h"
+#include "velox/vector/tests/utils/VectorTestBase.h"
+
+using namespace facebook::velox;
+
+class RowBasedChecksumTest : public test::VectorTestBase, public testing::Test
{
+ protected:
+ static void SetUpTestSuite() {
+ memory::MemoryManager::testingSetInstance({});
+ }
+ // Simulate the checksum computation from VeloxHashShuffleWriter.
+ std::pair<int64_t, int64_t> computeChecksums(const RowVectorPtr& rv, const
std::vector<uint32_t>& rowOrder) {
+ row::UnsafeRowFast fast(rv);
+ auto rowType = std::dynamic_pointer_cast<const RowType>(rv->type());
+ auto fixedSize = row::UnsafeRowFast::fixedRowSize(rowType);
+ int32_t bufSize = fixedSize.value_or(1024);
+ std::vector<char> buffer(bufSize, 0);
+
+ int64_t checksumXor = 0;
+ int64_t checksumSum = 0;
+
+ for (auto row : rowOrder) {
+ auto size = fast.rowSize(row);
+ if (size > static_cast<int32_t>(buffer.size())) {
+ buffer.resize(size);
+ }
+ std::memset(buffer.data(), 0, size);
+ fast.serialize(row, buffer.data());
+
+ auto hash = static_cast<int64_t>(XXH64(buffer.data(), size, 0));
+ checksumXor ^= hash;
+ checksumSum += hash;
+ }
+
+ int64_t rotated = (static_cast<uint64_t>(checksumSum) << 27) |
(static_cast<uint64_t>(checksumSum) >> 37);
+ return {checksumXor ^ rotated, checksumSum};
+ }
+};
+
+TEST_F(RowBasedChecksumTest, orderIndependence) {
+ // Create a RowVector with 5 rows: (int, string)
+ auto rv = makeRowVector(
+ {"a", "b"},
+ {makeFlatVector<int32_t>({10, 20, 30, 40, 50}),
+ makeFlatVector<StringView>({"hello", "world", "foo", "bar", "baz"})});
+
+ // Compute checksum in original order
+ std::vector<uint32_t> order1 = {0, 1, 2, 3, 4};
+ auto [checksum1, _1] = computeChecksums(rv, order1);
+
+ // Compute checksum in reversed order
+ std::vector<uint32_t> order2 = {4, 3, 2, 1, 0};
+ auto [checksum2, _2] = computeChecksums(rv, order2);
+
+ // Compute checksum in shuffled order
+ std::vector<uint32_t> order3 = {2, 4, 0, 3, 1};
+ auto [checksum3, _3] = computeChecksums(rv, order3);
+
+ // All should be equal (order-independent)
+ EXPECT_EQ(checksum1, checksum2);
+ EXPECT_EQ(checksum1, checksum3);
+ EXPECT_NE(checksum1, 0); // Should be non-zero
+}
+
+TEST_F(RowBasedChecksumTest, differentDataProducesDifferentChecksum) {
+ auto rv1 = makeRowVector({"a"}, {makeFlatVector<int64_t>({1, 2, 3})});
+ auto rv2 = makeRowVector({"a"}, {makeFlatVector<int64_t>({1, 2, 4})}); //
last value different
+
+ std::vector<uint32_t> order = {0, 1, 2};
+ auto [checksum1, _1] = computeChecksums(rv1, order);
+ auto [checksum2, _2] = computeChecksums(rv2, order);
+
+ EXPECT_NE(checksum1, checksum2);
+}
+
+TEST_F(RowBasedChecksumTest, nullHandling) {
+ auto rv1 = makeRowVector({"a"}, {makeNullableFlatVector<int32_t>({1,
std::nullopt, 3})});
+ auto rv2 = makeRowVector({"a"}, {makeNullableFlatVector<int32_t>({1, 0,
3})}); // 0 vs null
+
+ std::vector<uint32_t> order = {0, 1, 2};
+ auto [checksum1, _1] = computeChecksums(rv1, order);
+ auto [checksum2, _2] = computeChecksums(rv2, order);
+
+ // null and 0 should produce different checksums
+ EXPECT_NE(checksum1, checksum2);
+}
+
+TEST_F(RowBasedChecksumTest, deterministic) {
+ auto rv =
+ makeRowVector({"a", "b"}, {makeFlatVector<int64_t>({100, 200, 300}),
makeFlatVector<double>({1.1, 2.2, 3.3})});
+
+ std::vector<uint32_t> order = {0, 1, 2};
+ auto [checksum1, _1] = computeChecksums(rv, order);
+ auto [checksum2, _2] = computeChecksums(rv, order);
+
+ // Same input, same order -> same result (deterministic)
+ EXPECT_EQ(checksum1, checksum2);
+}
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/GlutenSplitResult.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/GlutenSplitResult.java
index 96b2a3fc54..a22d27da11 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/GlutenSplitResult.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/GlutenSplitResult.java
@@ -25,6 +25,7 @@ public class GlutenSplitResult {
private final long totalBytesEvicted;
private final long[] partitionLengths;
private final long[] rawPartitionLengths;
+ private final long[] rowBasedChecksums;
private final long bytesToEvict;
private final long peakBytes;
private final long sortTime;
@@ -46,7 +47,8 @@ public class GlutenSplitResult {
double avgDictionaryFields,
long dictionarySize,
long[] partitionLengths,
- long[] rawPartitionLengths) {
+ long[] rawPartitionLengths,
+ long[] rowBasedChecksums) {
this.totalComputePidTime = totalComputePidTime;
this.totalWriteTime = totalWriteTime;
this.totalEvictTime = totalEvictTime;
@@ -55,6 +57,7 @@ public class GlutenSplitResult {
this.totalBytesEvicted = totalBytesEvicted;
this.partitionLengths = partitionLengths;
this.rawPartitionLengths = rawPartitionLengths;
+ this.rowBasedChecksums = rowBasedChecksums;
this.bytesToEvict = totalBytesToEvict;
this.peakBytes = peakBytes;
this.sortTime = totalSortTime;
@@ -99,6 +102,10 @@ public class GlutenSplitResult {
return rawPartitionLengths;
}
+ public long[] getRowBasedChecksums() {
+ return rowBasedChecksums;
+ }
+
public long getBytesToEvict() {
return bytesToEvict;
}
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleWriterJniWrapper.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleWriterJniWrapper.java
index 87685f8505..2f31a529d9 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleWriterJniWrapper.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ShuffleWriterJniWrapper.java
@@ -44,7 +44,8 @@ public class ShuffleWriterJniWrapper implements RuntimeAware {
int splitBufferSize,
double splitBufferReallocThreshold,
int partitionBufferEvictThreshold,
- long partitionWriterHandle);
+ long partitionWriterHandle,
+ boolean rowBasedChecksumEnabled);
public native long createSortShuffleWriter(
int numPartitions,
diff --git a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
index 1f5eaf3ac5..1f6919774d 100644
--- a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
+++ b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.metric.{GlutenCustomMetricsSuite,
GlutenSQLMetricsSuite}
import org.apache.spark.sql.execution.python._
import
org.apache.spark.sql.extension.{GlutenCollapseProjectExecTransformerSuite,
GlutenSessionExtensionSuite}
-import org.apache.spark.sql.gluten.{GlutenFallbackStrategiesSuite,
GlutenFallbackSuite}
+import org.apache.spark.sql.gluten.{GlutenFallbackStrategiesSuite,
GlutenFallbackSuite, GlutenRowBasedChecksumSuite}
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming._
@@ -1035,8 +1035,7 @@ class VeloxTestSettings extends BackendTestSettings {
// TODO: fix on Spark-4.1 introduced by
https://github.com/apache/spark/pull/47856
.exclude("SPARK-49386: test SortMergeJoin (with spill by size threshold)")
enableSuite[GlutenMathFunctionsSuite]
- // TODO: fix on Spark-4.1 see https://github.com/apache/spark/pull/50230
- // enableSuite[GlutenMapStatusEndToEndSuite]
+ enableSuite[GlutenMapStatusEndToEndSuite]
enableSuite[GlutenMetadataCacheSuite]
.exclude("SPARK-16336,SPARK-27961 Suggest fixing FileNotFoundException")
enableSuite[GlutenMiscFunctionsSuite]
@@ -1102,6 +1101,7 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenUnsafeRowChecksumSuite]
enableSuite[GlutenXPathFunctionsSuite]
enableSuite[GlutenFallbackSuite]
+ enableSuite[GlutenRowBasedChecksumSuite]
enableSuite[GlutenHashAggregationQuerySuite]
// TODO: fix on https://github.com/apache/gluten/issues/11919
.exclude("udaf with all data types")
diff --git a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/gluten/GlutenRowBasedChecksumSuite.scala b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/gluten/GlutenRowBasedChecksumSuite.scala
new file mode 100644
index 0000000000..8ee91ca133
--- /dev/null
+++ b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/gluten/GlutenRowBasedChecksumSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.gluten
+
+import org.apache.gluten.config.GlutenConfig
+
+import org.apache.spark.{MapOutputTrackerMaster, SparkConf}
+import org.apache.spark.sql.GlutenSQLTestsTrait
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * End-to-end tests for the row-based checksum (SPARK-51756) computed by
Gluten's
+ * ColumnarShuffleWriter. Verifies that `MapStatus.checksumValue` is
propagated, deterministic for
+ * identical input, and changes when row data changes.
+ */
+class GlutenRowBasedChecksumSuite extends GlutenSQLTestsTrait {
+
+ override def sparkConf: SparkConf = {
+ super.sparkConf
+ .set(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, "5")
+ .set(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key,
"false")
+ // Disable ANSI fallback to force Gluten's ColumnarShuffleWriter path.
+ .set(GlutenConfig.GLUTEN_ANSI_FALLBACK_ENABLED.key, "false")
+ }
+
+ private def getLatestShuffleChecksumValues(): Array[Long] = {
+ val tracker = spark.sparkContext.env.mapOutputTracker
+ .asInstanceOf[MapOutputTrackerMaster]
+ val latestShuffleId = tracker.shuffleStatuses.keys.max
+ tracker.shuffleStatuses(latestShuffleId).mapStatuses.map(_.checksumValue)
+ }
+
+ test("Gluten row-based checksum is deterministic") {
+ withSQLConf(
+ SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key -> "true",
+ SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key -> "false") {
+ withTable("t_det1", "t_det2") {
+ spark.range(500).repartition(5,
col("id")).write.mode("overwrite").saveAsTable("t_det1")
+ val checksums1 = getLatestShuffleChecksumValues()
+
+ spark.range(500).repartition(5,
col("id")).write.mode("overwrite").saveAsTable("t_det2")
+ val checksums2 = getLatestShuffleChecksumValues()
+
+ // Same input -> same checksumValue (deterministic)
+ assert(
+ checksums1.zip(checksums2).forall { case (a, b) => a == b },
+ s"Checksums not deterministic: ${checksums1.toSeq} vs
${checksums2.toSeq}")
+ }
+ }
+ }
+
+ test("Gluten row-based checksum detects data change") {
+ withSQLConf(
+ SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key -> "true",
+ SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key -> "false") {
+ withTable("t_diff1", "t_diff2") {
+ spark.range(500).repartition(5,
col("id")).write.mode("overwrite").saveAsTable("t_diff1")
+ val checksums1 = getLatestShuffleChecksumValues()
+
+ // Different data
+ spark
+ .range(500, 1000)
+ .repartition(5, col("id"))
+ .write
+ .mode("overwrite")
+ .saveAsTable("t_diff2")
+ val checksums2 = getLatestShuffleChecksumValues()
+
+ // Different input -> different checksumValue
+ assert(
+ checksums1.zip(checksums2).exists { case (a, b) => a != b },
+ s"Checksums should differ for different data: ${checksums1.toSeq} vs
${checksums2.toSeq}")
+ }
+ }
+ }
+}
diff --git a/shims/spark33/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala b/shims/spark33/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
new file mode 100644
index 0000000000..9b5000946a
--- /dev/null
+++ b/shims/spark33/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
+object GlutenMapStatusUtil {
+ def isRowBasedChecksumEnabled: Boolean = false
+
+ def createMapStatus(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ mapTaskId: Long,
+ checksumValue: Long): MapStatus = {
+ MapStatus(loc, uncompressedSizes, mapTaskId)
+ }
+}
diff --git a/shims/spark34/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala b/shims/spark34/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
new file mode 100644
index 0000000000..9b5000946a
--- /dev/null
+++ b/shims/spark34/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
+object GlutenMapStatusUtil {
+ def isRowBasedChecksumEnabled: Boolean = false
+
+ def createMapStatus(
+ loc: BlockManagerId,
+ uncompressedSizes: Array[Long],
+ mapTaskId: Long,
+ checksumValue: Long): MapStatus = {
+ MapStatus(loc, uncompressedSizes, mapTaskId)
+ }
+}
diff --git a/shims/spark35/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala b/shims/spark35/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
new file mode 100644
index 0000000000..9b5000946a
--- /dev/null
+++ b/shims/spark35/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
+object GlutenMapStatusUtil {
+  /** Always false: this shim never requests row-based shuffle checksums. */
+  def isRowBasedChecksumEnabled: Boolean = false
+
+  /** Builds a MapStatus; `checksumValue` is accepted but ignored by this shim. */
+  def createMapStatus(
+      loc: BlockManagerId,
+      uncompressedSizes: Array[Long],
+      mapTaskId: Long,
+      checksumValue: Long): MapStatus = MapStatus(loc, uncompressedSizes, mapTaskId)
+}
diff --git a/shims/spark40/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala b/shims/spark40/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
new file mode 100644
index 0000000000..9b5000946a
--- /dev/null
+++ b/shims/spark40/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
+object GlutenMapStatusUtil {
+  /** Always false: this shim never requests row-based shuffle checksums. */
+  def isRowBasedChecksumEnabled: Boolean = false
+
+  /** Builds a MapStatus; `checksumValue` is accepted but ignored by this shim. */
+  def createMapStatus(
+      loc: BlockManagerId,
+      uncompressedSizes: Array[Long],
+      mapTaskId: Long,
+      checksumValue: Long): MapStatus = MapStatus(loc, uncompressedSizes, mapTaskId)
+}
diff --git a/shims/spark41/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala b/shims/spark41/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
new file mode 100644
index 0000000000..bf8108c4b0
--- /dev/null
+++ b/shims/spark41/src/main/scala/org/apache/spark/shuffle/GlutenMapStatusUtil.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.storage.BlockManagerId
+
+object GlutenMapStatusUtil {
+  /** Enabled when either of the two SQLConf shuffle-checksum flags below is set. */
+  def isRowBasedChecksumEnabled: Boolean = {
+    val conf = SQLConf.get
+    conf.shuffleOrderIndependentChecksumEnabled || conf.shuffleChecksumMismatchFullRetryEnabled
+  }
+
+  /** Builds a MapStatus that carries `checksumValue` alongside the partition sizes. */
+  def createMapStatus(
+      loc: BlockManagerId,
+      uncompressedSizes: Array[Long],
+      mapTaskId: Long,
+      checksumValue: Long): MapStatus =
+    MapStatus(loc, uncompressedSizes, mapTaskId, checksumValue)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]