mridulm commented on a change in pull request #33451:
URL: https://github.com/apache/spark/pull/33451#discussion_r676871069
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.checksum;
+
+import java.io.*;
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.Checksum;
+
+import com.google.common.io.ByteStreams;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.corruption.Cause;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A set of utility functions for the shuffle checksum.
+ */
+@Private
+public class ShuffleChecksumHelper {
+  private static final Logger logger =
+    LoggerFactory.getLogger(ShuffleChecksumHelper.class);
+
+  public static final int CHECKSUM_CALCULATION_BUFFER = 8192;
+  public static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
+  public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
+
+  public static Checksum[] createPartitionChecksums(int numPartitions, String algorithm) {
+    return getChecksumsByAlgorithm(numPartitions, algorithm);
+  }
+
+  private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) {
+    Checksum[] checksums;
+    switch (algorithm) {
+      case "ADLER32":
+        checksums = new Adler32[num];
+        for (int i = 0; i < num; i++) {
+          checksums[i] = new Adler32();
+        }
+        return checksums;
+
+      case "CRC32":
+        checksums = new CRC32[num];
+        for (int i = 0; i < num; i++) {
+          checksums[i] = new CRC32();
+        }
+        return checksums;
+
+      default:
+        throw new UnsupportedOperationException(
+          "Unsupported shuffle checksum algorithm: " + algorithm);
+    }
+  }
+
+  public static Checksum getChecksumByAlgorithm(String algorithm) {
+    return getChecksumsByAlgorithm(1, algorithm)[0];
+  }
+
+  public static String getChecksumFileName(String blockName, String algorithm) {
+    // append the shuffle checksum algorithm as the file extension
+    return String.format("%s.%s", blockName, algorithm);
+  }
+
+  public static Checksum getChecksumByFileExtension(String fileName) {

Review comment:
We can remove this and use the algo passed in explicitly.
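The comment above suggests resolving the checksum instance from the configured algorithm instead of parsing it back out of the checksum file name. A minimal caller-side sketch of that idea, assuming the `ShuffleChecksumHelper` from the diff is on the classpath and that the algorithm string (e.g. "ADLER32" or "CRC32") is plumbed through from configuration; the class name below is made up for illustration:

```java
import java.util.zip.Checksum;

import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;

public class ExplicitAlgorithmSketch {
  public static void main(String[] args) {
    // The algorithm comes from the shuffle checksum configuration rather than
    // from a file name like "shuffle_0_0_0.checksum.ADLER32".
    String algorithm = "ADLER32";
    Checksum checksumAlgo = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm);
    System.out.println(checksumAlgo.getClass().getSimpleName()); // Adler32
  }
}
```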
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.checksum;
+
+import java.io.*;
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.Checksum;
+
+import com.google.common.io.ByteStreams;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.corruption.Cause;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A set of utility functions for the shuffle checksum.
+ */
+@Private
+public class ShuffleChecksumHelper {
+  private static final Logger logger =
+    LoggerFactory.getLogger(ShuffleChecksumHelper.class);
+
+  public static final int CHECKSUM_CALCULATION_BUFFER = 8192;
+  public static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
+  public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
+
+  public static Checksum[] createPartitionChecksums(int numPartitions, String algorithm) {
+    return getChecksumsByAlgorithm(numPartitions, algorithm);
+  }
+
+  private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) {
+    Checksum[] checksums;
+    switch (algorithm) {
+      case "ADLER32":
+        checksums = new Adler32[num];
+        for (int i = 0; i < num; i++) {
+          checksums[i] = new Adler32();
+        }
+        return checksums;
+
+      case "CRC32":
+        checksums = new CRC32[num];
+        for (int i = 0; i < num; i++) {
+          checksums[i] = new CRC32();
+        }
+        return checksums;
+
+      default:
+        throw new UnsupportedOperationException(
+          "Unsupported shuffle checksum algorithm: " + algorithm);
+    }
+  }
+
+  public static Checksum getChecksumByAlgorithm(String algorithm) {
+    return getChecksumsByAlgorithm(1, algorithm)[0];
+  }
+
+  public static String getChecksumFileName(String blockName, String algorithm) {
+    // append the shuffle checksum algorithm as the file extension
+    return String.format("%s.%s", blockName, algorithm);
+  }
+
+  public static Checksum getChecksumByFileExtension(String fileName) {
+    int index = fileName.lastIndexOf(".");
+    String algorithm = fileName.substring(index + 1);
+    return getChecksumsByAlgorithm(1, algorithm)[0];
+  }
+
+  private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException {
+    try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) {
+      ByteStreams.skipFully(in, reduceId * 8);
+      return in.readLong();
+    }
+  }
+
+  private static long calculateChecksumForPartition(
+      ManagedBuffer partitionData,
+      Checksum checksumAlgo) throws IOException {
+    InputStream in = partitionData.createInputStream();
+    byte[] buffer = new byte[CHECKSUM_CALCULATION_BUFFER];
+    try(CheckedInputStream checksumIn = new CheckedInputStream(in, checksumAlgo)) {
+      while (checksumIn.read(buffer, 0, CHECKSUM_CALCULATION_BUFFER) != -1) {}
+      return checksumAlgo.getValue();
+    }
+  }
+
+  /**
+   * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums.
+   *
+   * There're 3 different kinds of checksums for the same shuffle partition:
+   *   - checksum (c1) that is calculated by the shuffle data reader
+   *   - checksum (c2) that is calculated by the shuffle data writer and stored in the checksum file
+   *   - checksum (c3) that is recalculated during diagnosis
+   *
+   * And the diagnosis mechanism works like this:
+   * If c2 != c3, we suspect the corruption is caused by the DISK_ISSUE. Otherwise, if c1 != c3,
+   * we suspect the corruption is caused by the NETWORK_ISSUE. Otherwise, the cause remains
+   * CHECKSUM_VERIFY_PASS. In case of the any other failures, the cause remains UNKNOWN_ISSUE.
+   *
+   * @param checksumFile The checksum file that written by the shuffle writer
+   * @param reduceId The reduceId of the shuffle block
+   * @param partitionData The partition data of the shuffle block
+   * @param checksumByReader The checksum value that calculated by the shuffle data reader
+   * @return The cause of data corruption
+   */
+  public static Cause diagnoseCorruption(
+      File checksumFile,
+      int reduceId,
+      ManagedBuffer partitionData,
+      long checksumByReader) {
+    Cause cause;
+    try {
+      long diagnoseStart = System.currentTimeMillis();
+      // Try to get the checksum instance before reading the checksum file so that
+      // `UnsupportedOperationException` can be thrown first before `FileNotFoundException`
+      // when the checksum algorithm isn't supported.
+      Checksum checksumAlgo = getChecksumByFileExtension(checksumFile.getName());
+      long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId);
+      long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo);
+      long duration = System.currentTimeMillis() - diagnoseStart;
+      logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}",
+        duration, checksumFile.getAbsolutePath());
+      if (checksumByWriter != checksumByReCalculation) {
+        cause = Cause.DISK_ISSUE;
+      } else if (checksumByWriter != checksumByReader) {

Review comment:
I probably mentioned this before - can we move this to before we recompute checksum via `calculateChecksumForPartition` ? It is a cheap check, and if network is the culprit, we immediately return.
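For reference, the decision order spelled out in the javadoc of `diagnoseCorruption` above, reduced to a self-contained sketch: c2 (writer) is compared against c3 (recomputed from disk) first, and only then against c1 (reader). The trimmed `Cause` enum and the `diagnose` helper are illustrative stand-ins for the real classes; the review comment asks whether the cheap c1/c2 comparison could run before the relatively expensive recomputation that produces c3.

```java
public class DiagnosisOrderSketch {
  // Trimmed stand-in for org.apache.spark.network.corruption.Cause in the diff.
  enum Cause { DISK_ISSUE, NETWORK_ISSUE, CHECKSUM_VERIFY_PASS }

  // c1 = checksum computed by the shuffle reader over the fetched bytes,
  // c2 = checksum the writer stored in the checksum file,
  // c3 = checksum recomputed from the on-disk partition during diagnosis.
  static Cause diagnose(long c1, long c2, long c3) {
    if (c2 != c3) {
      return Cause.DISK_ISSUE;      // file no longer matches what the writer recorded
    } else if (c2 != c1) {
      return Cause.NETWORK_ISSUE;   // file is intact, so the bytes changed in transit
    } else {
      return Cause.CHECKSUM_VERIFY_PASS;
    }
  }

  public static void main(String[] args) {
    long writerChecksum = 0x1234L;
    System.out.println(diagnose(writerChecksum, writerChecksum, writerChecksum)); // CHECKSUM_VERIFY_PASS
    System.out.println(diagnose(0x9999L, writerChecksum, writerChecksum));        // NETWORK_ISSUE
    System.out.println(diagnose(0x9999L, writerChecksum, 0x5678L));               // DISK_ISSUE
  }
}
```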
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
##########
@@ -374,6 +376,32 @@ public int removeBlocks(String appId, String execId, String[] blockIds) {
         .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
   }
 
+  /**
+   * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums
+   */
+  public Cause diagnoseShuffleBlockCorruption(
+      String appId,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksumByReader) {
+    ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
+    String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum";
+    File probeFile = ExecutorDiskUtils.getFile(
+      executor.localDirs,
+      executor.subDirsPerLocalDir,
+      fileName);

Review comment:
Yeah, I meant something similar ... we don't need to do this for this PR btw; just thinking out loud.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -822,8 +836,15 @@ final class ShuffleBlockFetcherIterator(
           }
         } catch {
          case e: IOException =>
-            buf.release()
+            // When shuffle checksum is enabled, for a block that is corrupted twice,
+            // we'd calculate the checksum of the block by consuming the remaining data
+            // in the buf. So, we should release the buf later.
+            if (!(checksumEnabled && corruptedBlocks.contains(blockId))) {
+              buf.release()
+            }

Review comment:
To clarify, with this change, for all fetches after the first failure, we will diagnose (except if cause == disk) ? If yes, the change proposed in `diagnoseCorruption` will be more useful.
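The checksum file that `diagnoseShuffleBlockCorruption` probes for is read back with the fixed-width layout implied by `readChecksumByReduceId` above: one 8-byte long per reduce partition, stored in reduce-id order. A small standalone sketch of that layout; the file name, the values, and the writer loop are made up, and Guava's `ByteStreams` is used only to mirror the diff:

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

import com.google.common.io.ByteStreams;

public class ChecksumFileLayoutSketch {
  public static void main(String[] args) throws IOException {
    File checksumFile = File.createTempFile("shuffle_0_0_0.checksum", ".ADLER32");
    checksumFile.deleteOnExit();

    // Writer side: one long per reduce partition, in reduce-id order.
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile))) {
      out.writeLong(111L); // reduceId 0
      out.writeLong(222L); // reduceId 1
      out.writeLong(333L); // reduceId 2
    }

    // Reader side: skip straight to the entry of the requested reduce partition.
    int reduceId = 2;
    try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) {
      ByteStreams.skipFully(in, reduceId * 8L);
      System.out.println(in.readLong()); // 333
    }
  }
}
```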
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
##########
@@ -46,6 +46,43 @@
   protected volatile TransportClientFactory clientFactory;
   protected String appId;
 
+  protected TransportConf transportConf;

Review comment:
Yeah, I was not sure if this would be easy to do ... wanted to surface it anyway. Thanks for double checking @Ngone51 !

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
##########
@@ -374,6 +379,27 @@ public int removeBlocks(String appId, String execId, String[] blockIds) {
         .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
   }
 
+  /**
+   * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums
+   */
+  public Cause diagnoseShuffleBlockCorruption(
+      String appId,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksumByReader,
+      String algorithm) {
+    ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
+    String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm;

Review comment:
That is a cleaner solution, thanks ! Will also mean we can spread the checksums without needing to couple it with the shuffle file.

##########
File path: common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java
##########
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.corruption;
+
+/**
+ * The cause of shuffle data corruption.
+ */

Review comment:
nit: Add `@since` ?
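The file-name convention discussed a couple of comments above: the checksum file is the block's checksum name with the algorithm appended as the extension, so the external shuffle service can rebuild the name from the ids it already has. A sketch of that string construction; the `_0` literal follows the diff, everything else (ids, algorithm) is illustrative:

```java
public class ChecksumFileNameSketch {
  // Same formatting as ShuffleChecksumHelper.getChecksumFileName in the diff above.
  static String getChecksumFileName(String blockName, String algorithm) {
    return String.format("%s.%s", blockName, algorithm);
  }

  public static void main(String[] args) {
    int shuffleId = 5;
    long mapId = 42L;
    String algorithm = "ADLER32";
    // Mirrors the fileName built in diagnoseShuffleBlockCorruption above.
    String blockName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum";
    System.out.println(getChecksumFileName(blockName, algorithm));
    // -> shuffle_5_42_0.checksum.ADLER32
  }
}
```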
##########
File path: common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
##########
@@ -108,6 +116,107 @@ public void testCompatibilityWithOldVersion() {
     verifyOpenBlockLatencyMetrics(2, 2);
   }
 
+  private void checkDiagnosisResult(
+      String algorithm,
+      Cause expectedCaused) throws IOException {
+    String appId = "app0";
+    String execId = "execId";
+    int shuffleId = 0;
+    long mapId = 0;
+    int reduceId = 0;
+
+    // prepare the checksum file
+    File tmpDir = Files.createTempDir();
+    File checksumFile = new File(tmpDir,
+      "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum." + algorithm);
+    DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile));
+    long checksumByReader = 0L;
+    if (expectedCaused != Cause.UNSUPPORTED_CHECKSUM_ALGORITHM) {
+      Checksum checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm);
+      CheckedInputStream checkedIn = new CheckedInputStream(
+        blockMarkers[0].createInputStream(), checksum);
+      byte[] buffer = new byte[10];
+      ByteStreams.readFully(checkedIn, buffer, 0, (int) blockMarkers[0].size());
+      long checksumByWriter = checkedIn.getChecksum().getValue();
+
+      switch (expectedCaused) {
+        case DISK_ISSUE:
+          out.writeLong(-checksumByWriter);
+          checksumByReader = checksumByWriter;
+          break;
+
+        case NETWORK_ISSUE:
+          out.writeLong(checksumByWriter);
+          checksumByReader = -1 * checksumByWriter;

Review comment:
nit: `- checksumByWriter` ?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -971,7 +1002,51 @@ final class ShuffleBlockFetcherIterator(
         currentResult.mapIndex,
         currentResult.address,
         detectCorrupt && streamCompressedOrEncrypted,
-        currentResult.isNetworkReqDone))
+        currentResult.isNetworkReqDone,
+        Option(checkedIn)))
+  }
+
+  /**
+   * Get the suspect corruption cause for the corrupted block. It should be only invoked
+   * when checksum is enabled.

Review comment:
```suggestion
   * when checksum is enabled and corruption was detected at least once.
```

##########
File path: common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
##########
@@ -108,6 +116,107 @@ public void testCompatibilityWithOldVersion() {
     verifyOpenBlockLatencyMetrics(2, 2);
   }
 
+  private void checkDiagnosisResult(
+      String algorithm,
+      Cause expectedCaused) throws IOException {
+    String appId = "app0";
+    String execId = "execId";
+    int shuffleId = 0;
+    long mapId = 0;
+    int reduceId = 0;
+
+    // prepare the checksum file
+    File tmpDir = Files.createTempDir();
+    File checksumFile = new File(tmpDir,
+      "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum." + algorithm);
+    DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile));
+    long checksumByReader = 0L;
+    if (expectedCaused != Cause.UNSUPPORTED_CHECKSUM_ALGORITHM) {
+      Checksum checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm);
+      CheckedInputStream checkedIn = new CheckedInputStream(
+        blockMarkers[0].createInputStream(), checksum);
+      byte[] buffer = new byte[10];
+      ByteStreams.readFully(checkedIn, buffer, 0, (int) blockMarkers[0].size());
+      long checksumByWriter = checkedIn.getChecksum().getValue();
+
+      switch (expectedCaused) {
+        case DISK_ISSUE:
+          out.writeLong(-checksumByWriter);
+          checksumByReader = checksumByWriter;
+          break;

Review comment:
It will be fun if `checksumByWriter` ended up becoming `MIN_VALUE` :-) Do you want to handle that as well ?
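On the `MIN_VALUE` aside above: negating a long to fabricate a mismatching checksum is not guaranteed to produce a different value, because `-Long.MIN_VALUE` overflows back to `Long.MIN_VALUE`. A tiny sketch of that edge case; in practice Adler32/CRC32 values occupy only the low 32 bits of the long, so the test data would not actually hit it:

```java
public class NegationEdgeCaseSketch {
  public static void main(String[] args) {
    long checksumByWriter = Long.MIN_VALUE;     // hypothetical pathological value
    long fabricated = -checksumByWriter;        // two's-complement overflow
    System.out.println(fabricated == checksumByWriter); // true: the "corruption" is a no-op
  }
}
```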
##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
##########
@@ -374,6 +379,27 @@ public int removeBlocks(String appId, String execId, String[] blockIds) {
         .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
   }
 
+  /**
+   * Diagnose the possible cause of the shuffle data corruption by verify the shuffle checksums
+   */
+  public Cause diagnoseShuffleBlockCorruption(
+      String appId,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksumByReader,
+      String algorithm) {
+    ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
+    String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm;

Review comment:
nit: Btw, add a comment that this should be in sync with `IndexShuffleBlockResolver.getChecksumFile` ?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -834,17 +855,27 @@ final class ShuffleBlockFetcherIterator(
             pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
             // Set result to null to trigger another iteration of the while loop.
             result = null
-          } else {
-            if (buf.isInstanceOf[FileSegmentManagedBuffer]
-              || corruptedBlocks.contains(blockId)) {
-              throwFetchFailedException(blockId, mapIndex, address, e)
+          } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+            throwFetchFailedException(blockId, mapIndex, address, e)
+          } else if (corruptedBlocks.contains(blockId)) {
+            // It's the second time this block is detected corrupted
+            if (checksumEnabled) {
+              // Diagnose the cause of data corruption if shuffle checksum is enabled
+              val cause = diagnoseCorruption(checkedIn, address, blockId)
+              buf.release()
+              val errorMsg = s"Block $blockId is corrupted due to $cause."

Review comment:
Only if cause != `CHECKSUM_VERIFY_PASS` ?

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
