This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 4ed9cbd2a [CELEBORN-2263] Fix IndexOutOfBoundsException while reading
from S3
4ed9cbd2a is described below
commit 4ed9cbd2a6f2019b79f248e13cc99b14dc6e23e0
Author: Enrico Olivelli <[email protected]>
AuthorDate: Sun Feb 15 23:16:29 2026 +0800
[CELEBORN-2263] Fix IndexOutOfBoundsException while reading from S3
### What changes were proposed in this pull request?
Properly pass the size of the array to the InputStream that feeds the flush.
### Why are the changes needed?
Because without this change if the array is bigger than the buffer, then
the InputStream returns garbage, resulting in corrupted data on S3.
### Does this PR resolve a correctness bug?
Yes.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit test + Manual testing.
Closes #3600 from eolivelli/CELEBORN-2263-apache.
Authored-by: Enrico Olivelli <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../service/deploy/worker/storage/FlushTask.scala | 6 +-
.../deploy/worker/storage/FlushTaskSuite.scala | 145 +++++++++++++++++++++
2 files changed, 148 insertions(+), 3 deletions(-)
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
index 4effe460d..a0dd85717 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
@@ -41,7 +41,7 @@ abstract private[worker] class FlushTask(
copyBytes: Array[Byte],
length: Int): Array[Byte] = {
if (copyBytes != null && copyBytes.length >= length) {
- buffer.readBytes(copyBytes, 0, length)
+ buffer.getBytes(buffer.readerIndex(), copyBytes, 0, length)
copyBytes
} else {
ByteBufUtil.getBytes(buffer)
@@ -144,7 +144,7 @@ private[worker] class S3FlushTask(
override def flush(copyBytes: Array[Byte]): Unit = {
val readableBytes = buffer.readableBytes()
val bytes = convertBufferToBytes(buffer, copyBytes, readableBytes)
- val inputStream = new ByteArrayInputStream(bytes)
+ val inputStream = new ByteArrayInputStream(bytes, 0, readableBytes)
flush(inputStream) {
s3MultipartUploader.putPart(inputStream, partNumber, finalFlush)
source.incCounter(WorkerSource.S3_FLUSH_COUNT)
@@ -166,7 +166,7 @@ private[worker] class OssFlushTask(
override def flush(copyBytes: Array[Byte]): Unit = {
val readableBytes = buffer.readableBytes()
val bytes = convertBufferToBytes(buffer, copyBytes, readableBytes)
- val inputStream = new ByteArrayInputStream(bytes)
+ val inputStream = new ByteArrayInputStream(bytes, 0, readableBytes)
flush(inputStream) {
ossMultipartUploader.putPart(inputStream, partNumber, finalFlush)
source.incCounter(WorkerSource.OSS_FLUSH_COUNT)
diff --git
a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTaskSuite.scala
b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTaskSuite.scala
new file mode 100644
index 000000000..da7e2b456
--- /dev/null
+++
b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTaskSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage
+
+import java.io.ByteArrayInputStream
+
+import io.netty.buffer.{ByteBufAllocator, CompositeByteBuf,
UnpooledByteBufAllocator}
+import org.apache.commons.io.IOUtils
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchersSugar.eqTo
+import org.mockito.MockitoSugar.{verify, _}
+import org.scalatest.prop.TableDrivenPropertyChecks.forAll
+import org.scalatest.prop.Tables.Table
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.metrics.source.AbstractSource
+import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler
+import org.apache.celeborn.service.deploy.worker.WorkerSource
+
+class FlushTaskSuite extends CelebornFunSuite {
+
+ private val ALLOCATOR: ByteBufAllocator = UnpooledByteBufAllocator.DEFAULT
+
+ test("OssFlushTask flush should work with buffers of various sizes") {
+ runTest(
+ (
+ mockBuffer: CompositeByteBuf,
+ mockNotifier: FlushNotifier,
+ keepBuffer: Boolean,
+ mockSource: AbstractSource,
+ mockMultipartUploader: MultipartUploadHandler,
+ partNumber: Int,
+ finalFlush: Boolean) =>
+ new OssFlushTask(
+ mockBuffer,
+ mockNotifier,
+ keepBuffer,
+ mockSource,
+ mockMultipartUploader,
+ partNumber,
+ finalFlush),
+ (mockSource: AbstractSource, expectedLength: Int) => {
+ verify(mockSource).incCounter(WorkerSource.OSS_FLUSH_COUNT)
+ verify(mockSource).incCounter(WorkerSource.OSS_FLUSH_SIZE,
expectedLength)
+ })
+ }
+
+ test("S3FlushTask flush should work with buffers of various sizes") {
+ runTest(
+ (
+ mockBuffer: CompositeByteBuf,
+ mockNotifier: FlushNotifier,
+ keepBuffer: Boolean,
+ mockSource: AbstractSource,
+ mockMultipartUploader: MultipartUploadHandler,
+ partNumber: Int,
+ finalFlush: Boolean) =>
+ new S3FlushTask(
+ mockBuffer,
+ mockNotifier,
+ keepBuffer,
+ mockSource,
+ mockMultipartUploader,
+ partNumber,
+ finalFlush),
+ (mockSource: AbstractSource, expectedLength: Int) => {
+ verify(mockSource).incCounter(WorkerSource.S3_FLUSH_COUNT)
+ verify(mockSource).incCounter(WorkerSource.S3_FLUSH_SIZE,
expectedLength)
+ })
+ }
+
+ def runTest(
+ builder: (
+ CompositeByteBuf,
+ FlushNotifier,
+ Boolean,
+ AbstractSource,
+ MultipartUploadHandler,
+ Int,
+ Boolean) => DfsFlushTask,
+ metricsChecker: (AbstractSource, Int) => Unit) = {
+ val bytes = "another test data".getBytes("UTF-8")
+ val len = bytes.length
+
+ // Define the scenarios: (scenario name, size to allocate)
+ val scenarios = Table(
+ ("description", "allocatedSize"),
+ ("provider buffer is the same size as the buffer", len),
+ ("provider buffer is bigger", len + 10),
+ ("provider buffer smaller", len - 5))
+
+ forAll(scenarios) { (description, bufferSize) =>
+ val mockBuffer = spy(ALLOCATOR.compositeBuffer())
+ mockBuffer.writeBytes(bytes)
+ val mockNotifier = mock[FlushNotifier]
+ val mockSource = mock[AbstractSource]
+ val mockMultipartUploader = mock[MultipartUploadHandler]
+ val partNumber = 3
+ val finalFlush = false
+
+ val flushTask = builder(
+ mockBuffer,
+ mockNotifier,
+ false, // keepBuffer
+ mockSource,
+ mockMultipartUploader,
+ partNumber,
+ finalFlush)
+
+ val copyBytesArray = new Array[Byte](bufferSize)
+ flushTask.flush(copyBytesArray)
+
+ // buffer position is not moved
+ assert(mockBuffer.readableBytes() == bytes.length)
+
+ val streamCaptor = ArgumentCaptor.forClass(classOf[ByteArrayInputStream])
+ verify(mockMultipartUploader).putPart(
+ streamCaptor.capture(),
+ eqTo(partNumber),
+ eqTo(finalFlush))
+ metricsChecker(mockSource, bytes.length)
+
+ val capturedStream = streamCaptor.getValue
+ val capturedBytes = IOUtils.toByteArray(capturedStream);
+ assert(capturedBytes sameElements bytes, s"Content mismatch on:
$description")
+
+ mockBuffer.release()
+ }
+ }
+}