mridulm commented on a change in pull request #32385: URL: https://github.com/apache/spark/pull/32385#discussion_r638118918
########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java ########## @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** Request to get the cause of a corrupted block. Returns {@link CorruptionCause} */ +public class DiagnoseCorruption extends BlockTransferMessage { + private final String appId; + private final String execId; + public final String blockId; + public final long checksum; + + public DiagnoseCorruption(String appId, String execId, String blockId, long checksum) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.checksum = checksum; + } + + @Override + protected Type type() { + return Type.DIAGNOSE_CORRUPTION; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("blockId", blockId) + .append("checksum", checksum) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DiagnoseCorruption that = (DiagnoseCorruption) o; + + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; + if (!blockId.equals(that.blockId)) return false; + return checksum == that.checksum; Review comment: super nit: check `checksum` first ? It is the cheapest check (a sketch follows below). ########## File path: core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java ########## @@ -57,11 +58,15 @@ private long currChannelPosition; private long bytesWrittenToMergedFile = 0L; + private Checksum checksumCal = null; + private long[] partitionChecksums = new long[0]; + private final File outputFile; private File outputTempFile; private FileOutputStream outputFileStream; - private FileChannel outputFileChannel; + private CountingWritableChannel outputChannel; private BufferedOutputStream outputBufferedFileStream; + private CheckedOutputStream checkedOutputStream; Review comment: You don't need a reference to this or to `outputFileStream`; they can be removed. ########## File path: core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java ########## @@ -57,11 +58,15 @@ private long currChannelPosition; private long bytesWrittenToMergedFile = 0L; + private Checksum checksumCal = null; Review comment: nit: `checksumCal` -> `checksumAlgo` or `checksumImpl` ?
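Illustrating the first comment above (on `DiagnoseCorruption.equals`): a minimal, hedged sketch of the reordering, comparing the primitive `checksum` before the `String` fields since a failed `long` comparison is the cheapest way to bail out. Field names are taken from the diff; this is not the final implementation.

```java
@Override
public boolean equals(Object o) {
  if (this == o) return true;
  if (o == null || getClass() != o.getClass()) return false;

  DiagnoseCorruption that = (DiagnoseCorruption) o;

  // Cheapest comparison first: a primitive long equality check.
  if (checksum != that.checksum) return false;
  if (!appId.equals(that.appId)) return false;
  if (!execId.equals(that.execId)) return false;
  return blockId.equals(that.blockId);
}
```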
########## File path: core/src/main/scala/org/apache/spark/storage/DiskStore.scala ########## @@ -328,23 +329,3 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: override def deallocate(): Unit = source.close() } - -private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { - - private var count = 0L - - def getCount: Long = count - - override def write(src: ByteBuffer): Int = { - val written = sink.write(src) - if (written > 0) { - count += written - } - written - } - - override def isOpen(): Boolean = sink.isOpen() - - override def close(): Unit = sink.close() - -} Review comment: Nice unification ! ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java ########## @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** Request to get the cause of a corrupted block. Returns {@link CorruptionCause} */ +public class DiagnoseCorruption extends BlockTransferMessage { + private final String appId; + private final String execId; + public final String blockId; + public final long checksum; + + public DiagnoseCorruption(String appId, String execId, String blockId, long checksum) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.checksum = checksum; + } + + @Override + protected Type type() { + return Type.DIAGNOSE_CORRUPTION; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("blockId", blockId) + .append("checksum", checksum) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DiagnoseCorruption that = (DiagnoseCorruption) o; + + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; + if (!blockId.equals(that.blockId)) return false; + return checksum == that.checksum; + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + blockId.hashCode(); + result = 31 * result + (int) checksum; Review comment: nit: checksum -> `Long.hashCode(checksum)` ? 
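For the `hashCode` nit at the end of the previous comment, a small illustrative version that folds in both halves of the `long` via `Long.hashCode` instead of truncating with an `(int)` cast (same fields as the diff; shown only as a sketch):

```java
@Override
public int hashCode() {
  int result = appId.hashCode();
  result = 31 * result + execId.hashCode();
  result = 31 * result + blockId.hashCode();
  // Long.hashCode mixes the upper and lower 32 bits rather than dropping the upper half.
  result = 31 * result + Long.hashCode(checksum);
  return result;
}
```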
########## File path: core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java ########## @@ -21,6 +21,7 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteBuffer; Review comment: Revert ? ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java ########## @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.corruption.Cause; + +/** Response to the {@link DiagnoseCorruption} */ +public class CorruptionCause extends BlockTransferMessage { + public Cause cause; + + public CorruptionCause(Cause cause) { + this.cause = cause; + } + + @Override + protected Type type() { + return Type.CORRUPTION_CAUSE; + } + + @Override + public String toString() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("cause", cause) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CorruptionCause that = (CorruptionCause) o; + return cause == that.cause; + } + + @Override + public int hashCode() { + return cause.hashCode(); + } + + @Override + public int encodedLength() { + return 4; /* encoded length of cause */ + } + + @Override + public void encode(ByteBuf buf) { + buf.writeInt(cause.ordinal()); Review comment: `int` -> `byte` ? ########## File path: core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java ########## @@ -100,12 +110,12 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I @Override public MapOutputCommitMessage commitAllPartitions() throws IOException { // Check the position after transferTo loop to see if it is in the right position and raise a - // exception if it is incorrect. The position will not be increased to the expected length + // exception if it is incorrect. The po sition will not be increased to the expected length Review comment: Revert this ? Looks like an accidental change. ########## File path: core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java ########## @@ -243,6 +265,10 @@ public void close() { isClosed = true; partitionLengths[partitionId] = count; bytesWrittenToMergedFile += count; + if (checksumCal != null) { + partitionChecksums[partitionId] = checksumCal.getValue(); + checksumCal.reset(); + } Review comment: nit: pull it into `saveChecksum(partitionId)` ? 
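A possible shape for the `saveChecksum(partitionId)` extraction suggested just above, assuming `checksumCal` and `partitionChecksums` stay fields of `LocalDiskShuffleMapOutputWriter` as in the diff; `close()` would then call this helper after recording the partition length. A sketch, not the final code:

```java
// Hypothetical helper: record the running checksum for the partition that just
// finished and reset the calculator for the next partition.
private void saveChecksum(int partitionId) {
  if (checksumCal != null) {
    partitionChecksums[partitionId] = checksumCal.getValue();
    checksumCal.reset();
  }
}
```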
########## File path: core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java ########## @@ -131,28 +141,40 @@ private void cleanUp() throws IOException { if (outputBufferedFileStream != null) { outputBufferedFileStream.close(); } - if (outputFileChannel != null) { - outputFileChannel.close(); + if (outputChannel != null) { + outputChannel.close(); + } + if (checkedOutputStream != null) { + checkedOutputStream.close(); } if (outputFileStream != null) { outputFileStream.close(); } + if (checksumCal != null) { + checksumCal.reset(); + } } private void initStream() throws IOException { if (outputFileStream == null) { outputFileStream = new FileOutputStream(outputTempFile, true); } + if (checksumCal != null && checkedOutputStream == null) { + checkedOutputStream = new CheckedOutputStream(outputFileStream, checksumCal); + } if (outputBufferedFileStream == null) { - outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); + outputBufferedFileStream = new BufferedOutputStream( + checksumCal != null ? checkedOutputStream : outputFileStream, bufferSize); } Review comment: If `outputFileStream` is `null`, we should reset `checkedOutputStream` and `outputBufferedFileStream` - not sure why we had individual `null` checks earlier. Also, need these to be reset to `null` in `cleanUp` ... Btw, if we remove references to these streams as I suggested above, it will make the `initChannel`/`initStream`/`cleanUp` simpler as well (and also fix this comment). Note: these are not specifically due to this PR, but they end up getting flagged here. ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java ########## @@ -47,6 +48,15 @@ protected volatile TransportClientFactory clientFactory; protected String appId; + public Cause diagnoseCorruption( Review comment: Include javadoc ? ########## File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ########## @@ -1762,7 +1762,7 @@ private[spark] class DAGScheduler( } if (shouldAbortStage) { - val abortMessage = if (disallowStageRetryForTest) { + val abortMessage = if (false) { Review comment: revert ? ########## File path: core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java ########## @@ -26,12 +26,12 @@ */ final class SpillInfo { final long[] partitionLengths; + final long[] partitionChecksums; final File file; - final TempShuffleBlockId blockId; - SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + SpillInfo(int numPartitions, File file, boolean checksumEnabled) { this.partitionLengths = new long[numPartitions]; + this.partitionChecksums = checksumEnabled ? new long[numPartitions] : new long[0]; Review comment: We are using `null` in `MapOutputCommitMessage` while an empty array here when checksum is disabled. Unify to a single idiom ? Given `writeMetadataFileAndCommit` is depending on an empty array (based on how it is written up right now), thoughts on using `new long[0]` ? (Btw, use a constant `EMPTY_LONG_ARRAY` if deciding to use `new long[0]`.) ########## File path: core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala ########## @@ -104,6 +108,40 @@ private[spark] class NettyBlockTransferService( } } + override def diagnoseCorruption( + host: String, + port: Int, + execId: String, + blockId: String, + checksum: Long): Cause = { + // A monitor for the thread to wait on. 
+ val result = Promise[Cause]() + val client = clientFactory.createClient(host, port) + client.sendRpc(new DiagnoseCorruption(appId, execId, blockId, checksum).toByteBuffer, + new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + val cause = BlockTransferMessage.Decoder + .fromByteBuffer(response).asInstanceOf[CorruptionCause] + result.success(cause.cause) + } + + override def onFailure(e: Throwable): Unit = { + logger.warn("Failed to get the corruption cause.", e) + result.success(Cause.UNKNOWN) + } + }) + val timeout = new RpcTimeout( + conf.get(Network.NETWORK_TIMEOUT).seconds, + Network.NETWORK_TIMEOUT.key) + try { + timeout.awaitResult(result.future) Review comment: Any implications of making this a sync call where we are blocking the thread ? ########## File path: core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala ########## @@ -174,6 +192,26 @@ private[spark] class IndexShuffleBlockResolver( } } + private def getChecksums(checksumFile: File, blockNum: Int): Array[Long] = { + if (!checksumFile.exists()) return null + val checksums = new ArrayBuffer[Long] + // Read the checksums of blocks + var in: DataInputStream = null + try { + in = new DataInputStream(new NioBufferedFileInputStream(checksumFile)) + while (checksums.size < blockNum) { + checksums += in.readLong() + } + } catch { + case _: IOException | _: EOFException => + return null + } finally { + in.close() + } + Review comment: Something like this might be better ? ```suggestion Try(Utils.tryWithResource(new DataInputStream(new NioBufferedFileInputStream(checksumFile))) { in => Array.tabulate(blockNum)(_ => in.readLong()) }).getOrElse(null) ``` ########## File path: core/src/main/scala/org/apache/spark/storage/BlockManager.scala ########## @@ -275,6 +279,45 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString + override def diagnoseShuffleBlockCorruption(blockId: BlockId, clientChecksum: Long): Cause = { + assert(blockId.isInstanceOf[ShuffleBlockId], + s"Corruption diagnosis only supports shuffle block yet, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] + val checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId) + val reduceId = shuffleBlock.reduceId + if (checksumFile.exists()) { + var in: DataInputStream = null + try { + val channel = Files.newByteChannel(checksumFile.toPath) + channel.position(reduceId * 8L) + in = new DataInputStream(Channels.newInputStream(channel)) + val goldenChecksum = in.readLong() Review comment: Extract out `readChecksum` and `computeChecksum` methods ? Btw, tryWithResource { DataInputStream(FileInputStream()).skip(reduceId * 8L).readLong() } would do the trick for readChecksum. 
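The `readChecksum`/`computeChecksum` extraction suggested above targets Scala code in `BlockManager.scala`; what follows is only a Java-flavoured sketch of the intended split. The 8-bytes-per-reduceId layout and the Adler32/8192-byte-buffer details come from the diff; the method signatures and helpers are assumptions.

```java
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.zip.Adler32;
import java.util.zip.CheckedInputStream;

// Read the stored ("golden") checksum for one reduce partition from the checksum file.
static long readChecksum(File checksumFile, int reduceId) throws IOException {
  try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) {
    in.skipBytes(reduceId * 8);  // each stored checksum is one long
    return in.readLong();
  }
}

// Recompute the checksum of the block data by draining it through a CheckedInputStream.
static long computeChecksum(InputStream blockData) throws IOException {
  try (CheckedInputStream in = new CheckedInputStream(blockData, new Adler32())) {
    byte[] buffer = new byte[8192];
    while (in.read(buffer) != -1) {
      // reading is enough; the wrapper updates the checksum as bytes pass through
    }
    return in.getChecksum().getValue();
  }
}
```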
########## File path: core/src/main/scala/org/apache/spark/storage/BlockManager.scala ########## @@ -275,6 +279,45 @@ private[spark] class BlockManager( override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString + override def diagnoseShuffleBlockCorruption(blockId: BlockId, clientChecksum: Long): Cause = { + assert(blockId.isInstanceOf[ShuffleBlockId], + s"Corruption diagnosis only supports shuffle block yet, but got $blockId") + val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] + val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver] + val checksumFile = resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId) + val reduceId = shuffleBlock.reduceId + if (checksumFile.exists()) { + var in: DataInputStream = null + try { + val channel = Files.newByteChannel(checksumFile.toPath) + channel.position(reduceId * 8L) + in = new DataInputStream(Channels.newInputStream(channel)) + val goldenChecksum = in.readLong() + val blockData = resolver.getBlockData(blockId) + val checksumIn = new CheckedInputStream(blockData.createInputStream(), new Adler32) + val buffer = new Array[Byte](8192) + while (checksumIn.read(buffer, 0, 8192) != -1) {} + val recalculatedChecksum = checksumIn.getChecksum.getValue Review comment: We are not closing `checksumIn` btw. ########## File path: core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java ########## @@ -76,6 +81,11 @@ public LocalDiskShuffleMapOutputWriter( (int) (long) sparkConf.get( package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; this.partitionLengths = new long[numPartitions]; + boolean checksumEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_CHECKSUM()); + if (checksumEnabled) { + this.checksumCal = new Adler32(); Review comment: Pull the `Checksum` management out and get everyone to depend on that ? Will also allow us to change it in future if required. (A sketch follows below.) ########## File path: core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala ########## @@ -76,6 +77,9 @@ private[spark] class DiskBlockObjectWriter( private var initialized = false private var streamOpen = false private var hasBeenClosed = false + private var checksumEnabled = false + private var checksumCal: Checksum = null + private var checksumOutputStream: CheckedOutputStream = null Review comment: Same comment as above - reduce the number of streams kept as fields ? ########## File path: core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala ########## @@ -104,6 +108,40 @@ private[spark] class NettyBlockTransferService( + override def diagnoseCorruption( + host: String, + port: Int, + execId: String, + blockId: String, + checksum: Long): Cause = { + // A monitor for the thread to wait on. + val result = Promise[Cause]() + val client = clientFactory.createClient(host, port) + client.sendRpc(new DiagnoseCorruption(appId, execId, blockId, checksum).toByteBuffer, + new RpcResponseCallback { + override def onSuccess(response: ByteBuffer): Unit = { + val cause = BlockTransferMessage.Decoder + .fromByteBuffer(response).asInstanceOf[CorruptionCause] + result.success(cause.cause) + } + + override def onFailure(e: Throwable): Unit = { + logger.warn("Failed to get the corruption cause.", e) + result.success(Cause.UNKNOWN) + } + }) + val timeout = new RpcTimeout( + conf.get(Network.NETWORK_TIMEOUT).seconds, Review comment: `seconds` -> `millis` ? 
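Going back to the comment above about pulling the `Checksum` management out of `LocalDiskShuffleMapOutputWriter`: one hedged way to centralize it so callers never name `Adler32` directly. The class and method names here are invented for illustration; they are not part of the PR.

```java
import java.util.zip.Adler32;
import java.util.zip.Checksum;

// Hypothetical single owner of the checksum algorithm choice; writers and spill
// paths ask this class for a Checksum instead of instantiating Adler32 themselves,
// so the algorithm can be swapped in one place later.
final class ShuffleChecksumSupport {
  static final long[] EMPTY_CHECKSUMS = new long[0];

  private ShuffleChecksumSupport() {}

  static Checksum createChecksum() {
    return new Adler32();
  }

  static long[] createPartitionChecksums(boolean checksumEnabled, int numPartitions) {
    return checksumEnabled ? new long[numPartitions] : EMPTY_CHECKSUMS;
  }
}
```

The `EMPTY_CHECKSUMS` constant would also cover the earlier `EMPTY_LONG_ARRAY` suggestion for `SpillInfo` and `MapOutputCommitMessage`.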
########## File path: core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala ########## @@ -333,13 +394,40 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } + + checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) => + val out = new DataOutputStream( + new BufferedOutputStream( + new FileOutputStream(checksumTmp) + ) + ) + Utils.tryWithSafeFinally { + checksums.foreach(out.writeLong) + } { + out.close() + } + + if (checksumFile.exists()) { + checksumFile.delete() + } + if (!checksumTmp.renameTo(checksumFile)) { + // It's not worthwhile to fail here after index file and data file are already + // successfully stored due to checksum is only used for the corner error case. + logWarning("fail to rename file " + checksumTmp + " to " + checksumFile) + } + } } } } finally { logDebug(s"Shuffle index for mapId $mapId: ${lengths.mkString("[", ",", "]")}") if (indexTmp.exists() && !indexTmp.delete()) { logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") } + checksumTmpOpt.foreach { checksumTmp => + if (checksumTmp.exists() && !checksumTmp.delete()) { + logError(s"Failed to delete temporary checksum file at ${checksumTmp.getAbsolutePath}") Review comment: `logInfo` ? ########## File path: common/network-common/src/main/java/org/apache/spark/network/corruption/Cause.java ########## @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.corruption; + +public enum Cause { + DISK, NETWORK, UNKNOWN; Review comment: `UNKNOWN` is handling three cases right now: * No checksum available for validation. * diagnosis failed due to some reason (timeout/failure/etc). * Checksum matches, no corruption detected. Anything else ? For the last, move it to a separate `Cause` ? 
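To make the last point concrete, a sketch of what splitting the "checksum matched" outcome out of `UNKNOWN` could look like. The added value name is invented here for illustration and is not taken from the PR:

```java
package org.apache.spark.network.corruption;

public enum Cause {
  DISK,
  NETWORK,
  // Hypothetical: checksums matched, so no corruption was detected on the server side.
  CHECKSUM_VERIFY_PASS,
  // Left for: no checksum available, or the diagnosis itself failed (timeout/failure/etc).
  UNKNOWN
}
```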
########## File path: core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala ########## @@ -333,13 +394,40 @@ private[spark] class IndexShuffleBlockResolver( if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) } + + checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) => + val out = new DataOutputStream( + new BufferedOutputStream( + new FileOutputStream(checksumTmp) + ) + ) + Utils.tryWithSafeFinally { + checksums.foreach(out.writeLong) + } { + out.close() + } + + if (checksumFile.exists()) { + checksumFile.delete() + } + if (!checksumTmp.renameTo(checksumFile)) { + // It's not worthwhile to fail here after index file and data file are already + // successfully stored due to checksum is only used for the corner error case. + logWarning("fail to rename file " + checksumTmp + " to " + checksumFile) + } + } Review comment: ```suggestion checksumFileOpt.foreach { checksumFile => val checksumTmp = checksumTmpOpt.get Utils.tryWithResource(new DataOutputStream(new BufferedOutputStream( new FileOutputStream(checksumTmp)))) { out => checksums.foreach(out.writeLong) } if (checksumFile.exists()) { checksumFile.delete() } if (!checksumTmp.renameTo(checksumFile)) { // It's not worthwhile to fail here after index file and data file are already // successfully stored due to checksum is only used for the corner error case. logWarning("fail to rename file " + checksumTmp + " to " + checksumFile) } } ```
