kazuyukitanimura commented on code in PR #1192: URL: https://github.com/apache/datafusion-comet/pull/1192#discussion_r1902145561
########## common/src/main/scala/org/apache/comet/CometConf.scala: ########## @@ -272,18 +272,19 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) - val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf( - s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") - .doc( - "The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. " + - "Compression can be disabled by setting spark.shuffle.compress=false.") - .stringConf - .checkValues(Set("zstd")) - .createWithDefault("zstd") + val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") + .doc( + "The codec of Comet native shuffle used to compress shuffle data. lz4, zstd, and " + + "snappy are supported. Compression can be disabled by setting " + + "spark.shuffle.compress=false.") + .stringConf + .checkValues(Set("zstd", "lz4", "snappy")) + .createWithDefault("lz4") val COMET_EXEC_SHUFFLE_COMPRESSION_LEVEL: ConfigEntry[Int] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.level") Review Comment: nit: if this is only for zstd, this config name should be zstd specific? ########## native/core/src/execution/shuffle/shuffle_writer.rs: ########## @@ -1567,17 +1585,41 @@ pub fn write_ipc_compressed<W: Write + Seek>( let mut timer = ipc_time.timer(); let start_pos = output.stream_position()?; - // write ipc_length placeholder - output.write_all(&[0u8; 8])?; + // seek past ipc_length placeholder + output.seek_relative(8)?; + + // write number of columns because JVM side needs to know how many addresses to allocate + let field_count = batch.schema().fields().len(); + output.write_all(&field_count.to_le_bytes())?; Review Comment: Just for me to understand, this was not previously written. Where is this read? ########## native/core/src/execution/shuffle/shuffle_writer.rs: ########## @@ -1567,17 +1585,41 @@ pub fn write_ipc_compressed<W: Write + Seek>( let mut timer = ipc_time.timer(); let start_pos = output.stream_position()?; - // write ipc_length placeholder - output.write_all(&[0u8; 8])?; + // seek past ipc_length placeholder + output.seek_relative(8)?; Review Comment: Are these skipped 8 bytes guaranteed to be zero or we do not have to worry about the existing contents? ########## spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala: ########## @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import java.io.{EOFException, InputStream} +import java.nio.{ByteBuffer, ByteOrder} +import java.nio.channels.{Channels, ReadableByteChannel} + +import org.apache.spark.TaskContext +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.Native +import org.apache.comet.vector.NativeUtil + +/** + * This iterator wraps a Spark input stream that is reading shuffle blocks generated by the Comet + * native ShuffleWriterExec and then calls native code to decompress and decode the shuffle blocks + * and use Arrow FFI to return the Arrow record batch. + */ +case class NativeBatchDecoderIterator( + var in: InputStream, + taskContext: TaskContext, + decodeTime: SQLMetric) + extends Iterator[ColumnarBatch] { + + private var isClosed = false + private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN) + private val native = new Native() + private val nativeUtil = new NativeUtil() + private var currentBatch: ColumnarBatch = null + private var batch = fetchNext() + + import NativeBatchDecoderIterator.threadLocalDataBuf + + if (taskContext != null) { + taskContext.addTaskCompletionListener[Unit](_ => { + close() + }) + } + + private val channel: ReadableByteChannel = if (in != null) { + Channels.newChannel(in) + } else { + null + } + + def hasNext(): Boolean = { + if (channel == null || isClosed) { + return false + } + if (batch.isDefined) { + return true + } + + // Release the previous batch. + if (currentBatch != null) { + currentBatch.close() + currentBatch = null + } + + batch = fetchNext() + if (batch.isEmpty) { + close() + return false + } + true + } + + def next(): ColumnarBatch = { + if (!hasNext) { + throw new NoSuchElementException + } + + val nextBatch = batch.get + + currentBatch = nextBatch + batch = None + currentBatch + } + + private def fetchNext(): Option[ColumnarBatch] = { + if (channel == null || isClosed) { + return None + } + + // read compressed batch size from header + try { + longBuf.clear() + while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} + } catch { + case _: EOFException => + close() + return None + } + + // If we reach the end of the stream, we are done, or if we read partial length + // then the stream is corrupted. + if (longBuf.hasRemaining) { + if (longBuf.position() == 0) { + close() + return None + } + throw new EOFException("Data corrupt: unexpected EOF while reading compressed ipc lengths") + } + + // get compressed length (including headers) + longBuf.flip() + val compressedLength = longBuf.getLong.toInt + + // read field count from header + longBuf.clear() + while (longBuf.hasRemaining && channel.read(longBuf) >= 0) {} + if (longBuf.hasRemaining) { + throw new EOFException("Data corrupt: unexpected EOF while reading field count") + } + longBuf.flip() + val fieldCount = longBuf.getLong.toInt + + // read body + val bytesToRead = compressedLength - 8 + var dataBuf = threadLocalDataBuf.get() + if (dataBuf.capacity() < bytesToRead) { + val newCapacity = bytesToRead * 2 Review Comment: Can this ever overflow by `* 2`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org