mccheah commented on a change in pull request #25007:
[SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
URL: https://github.com/apache/spark/pull/25007#discussion_r305526405
##########
File path:
core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
##########
@@ -142,82 +109,54 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
intercept[IllegalStateException] {
stream.write(p)
}
- assert(writer.getNumBytesWritten === D_LEN)
+ assert(writer.getNumBytesWritten === data(p).length)
}
- mapOutputWriter.commitAllPartitions()
- val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble }.toArray
- assert(partitionSizesInMergedFile === partitionLengths)
- assert(mergedOutputFile.length() === partitionLengths.sum)
- assert(data === readRecordsFromFile(false))
+ verifyWrittenRecords()
}
test("writing to a channel") {
(0 until NUM_PARTITIONS).foreach { p =>
val writer = mapOutputWriter.getPartitionWriter(p)
- val channel = writer.openTransferrableChannel()
- val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
- val intBuffer = byteBuffer.asIntBuffer()
- intBuffer.put(data(p))
- val numBytes = byteBuffer.remaining()
val outputTempFile = File.createTempFile("channelTemp", "", tempDir)
val outputTempFileStream = new FileOutputStream(outputTempFile)
- Utils.copyStream(
- new ByteBufferInputStream(byteBuffer),
- outputTempFileStream,
- closeStreams = true)
+ outputTempFileStream.write(data(p))
+ outputTempFileStream.close()
val tempFileInput = new FileInputStream(outputTempFile)
- channel.transferFrom(tempFileInput.getChannel, 0L, numBytes)
- // Bytes require * 4
- channel.close()
- tempFileInput.close()
- assert(writer.getNumBytesWritten === D_LEN * 4)
+ val channel = writer.openTransferrableChannel()
+ Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput =>
+ Utils.tryWithResource(writer.openTransferrableChannel()) { channel =>
+ channel.transferFrom(tempFileInput.getChannel, 0L, data(p).length)
+ }
+ }
+ assert(writer.getNumBytesWritten === data(p).length)
}
- mapOutputWriter.commitAllPartitions()
- val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble }.toArray
- assert(partitionSizesInMergedFile === partitionLengths)
- assert(mergedOutputFile.length() === partitionLengths.sum)
- assert(data === readRecordsFromFile(true))
+ verifyWrittenRecords()
}
- test("copyStreams with an outputstream") {
+ private def readRecordsFromFile() = {
+ var startOffset = 0L
+ val result = new Array[Array[Byte]](NUM_PARTITIONS)
(0 until NUM_PARTITIONS).foreach { p =>
- val writer = mapOutputWriter.getPartitionWriter(p)
- val stream = writer.openStream()
- val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
- val intBuffer = byteBuffer.asIntBuffer()
- intBuffer.put(data(p))
- val in = new ByteArrayInputStream(byteBuffer.array())
- Utils.copyStream(in, stream, false, false)
- in.close()
- stream.close()
- assert(writer.getNumBytesWritten === D_LEN * 4)
+ val partitionSize = data(p).length
+ if (partitionSize > 0) {
+ val in = new FileInputStream(mergedOutputFile)
+ in.getChannel.position(startOffset)
+ val lin = new LimitedInputStream(in, partitionSize)
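The hunk cuts off at the LimitedInputStream line, so the rest of the new
read-back helper is not shown here. A minimal sketch of how it could continue
from the visible added lines, assuming Guava's ByteStreams.readFully for the
bounded read (everything past the LimitedInputStream line is an assumption,
not the PR's actual code):

    import java.io.FileInputStream
    import com.google.common.io.ByteStreams
    import org.apache.spark.network.util.LimitedInputStream

    // Sketch only: a plausible completion of readRecordsFromFile() based on
    // the added lines above. Suite fields (mergedOutputFile, data,
    // NUM_PARTITIONS) are assumed to exist as shown in the diff.
    private def readRecordsFromFile(): Array[Array[Byte]] = {
      var startOffset = 0L
      val result = new Array[Array[Byte]](NUM_PARTITIONS)
      (0 until NUM_PARTITIONS).foreach { p =>
        val partitionSize = data(p).length
        if (partitionSize > 0) {
          val in = new FileInputStream(mergedOutputFile)
          in.getChannel.position(startOffset)
          // Bound the stream to this partition so the read cannot spill into
          // the next partition's region of the merged file.
          val lin = new LimitedInputStream(in, partitionSize)
          val bytes = new Array[Byte](partitionSize)
          ByteStreams.readFully(lin, bytes)
          lin.close()
          result(p) = bytes
        } else {
          result(p) = Array.empty[Byte]
        }
        startOffset += partitionSize
      }
      result
    }

Reading each partition through a position-plus-limit window mirrors how the
merged file is laid out: partitions are written back to back, so the offset
for partition p is the sum of the preceding partition lengths.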
Review comment:
I tackled this a bit differently, but in a similar spirit. Please see the updated logic.
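For context, the verifyWrittenRecords() call that replaces the three removed
assertions in both tests is defined elsewhere in the file and is not part of
this hunk. A minimal sketch of what such a shared helper could look like,
assuming it commits the map output and then compares the merged file against
data (all names below other than those visible in the diff are assumptions):

    // Sketch only: folds the previously duplicated commit-and-assert steps
    // into one place; the PR's actual helper body is not shown in this hunk.
    private def verifyWrittenRecords(): Unit = {
      mapOutputWriter.commitAllPartitions()
      val expectedLengths = data.map(_.length.toDouble)
      assert(partitionSizesInMergedFile === expectedLengths)
      assert(mergedOutputFile.length() === data.map(_.length).sum)
      assert(data === readRecordsFromFile())
    }

Centralizing the commit and the post-write checks keeps each test focused on a
single write path (stream vs. channel) instead of repeating the verification
logic.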