This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ed9cbd2a [CELEBORN-2263] Fix IndexOutOfBoundsException while reading 
from S3
4ed9cbd2a is described below

commit 4ed9cbd2a6f2019b79f248e13cc99b14dc6e23e0
Author: Enrico Olivelli <[email protected]>
AuthorDate: Sun Feb 15 23:16:29 2026 +0800

    [CELEBORN-2263] Fix IndexOutOfBoundsException while reading from S3
    
    ### What changes were proposed in this pull request?
    
    Properly pass the size of the array to the InputStream that feeds the flush.
    
    ### Why are the changes needed?
    
    Because without this change, if the array is bigger than the buffer, then
the InputStream returns trailing garbage, resulting in corrupted data on S3.
    
    ### Does this PR resolve a correctness bug?
    
    Yes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit test + Manual testing.
    
    Closes #3600 from eolivelli/CELEBORN-2263-apache.
    
    Authored-by: Enrico Olivelli <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../service/deploy/worker/storage/FlushTask.scala  |   6 +-
 .../deploy/worker/storage/FlushTaskSuite.scala     | 145 +++++++++++++++++++++
 2 files changed, 148 insertions(+), 3 deletions(-)

diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
index 4effe460d..a0dd85717 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTask.scala
@@ -41,7 +41,7 @@ abstract private[worker] class FlushTask(
       copyBytes: Array[Byte],
       length: Int): Array[Byte] = {
     if (copyBytes != null && copyBytes.length >= length) {
-      buffer.readBytes(copyBytes, 0, length)
+      buffer.getBytes(buffer.readerIndex(), copyBytes, 0, length)
       copyBytes
     } else {
       ByteBufUtil.getBytes(buffer)
@@ -144,7 +144,7 @@ private[worker] class S3FlushTask(
   override def flush(copyBytes: Array[Byte]): Unit = {
     val readableBytes = buffer.readableBytes()
     val bytes = convertBufferToBytes(buffer, copyBytes, readableBytes)
-    val inputStream = new ByteArrayInputStream(bytes)
+    val inputStream = new ByteArrayInputStream(bytes, 0, readableBytes)
     flush(inputStream) {
       s3MultipartUploader.putPart(inputStream, partNumber, finalFlush)
       source.incCounter(WorkerSource.S3_FLUSH_COUNT)
@@ -166,7 +166,7 @@ private[worker] class OssFlushTask(
   override def flush(copyBytes: Array[Byte]): Unit = {
     val readableBytes = buffer.readableBytes()
     val bytes = convertBufferToBytes(buffer, copyBytes, readableBytes)
-    val inputStream = new ByteArrayInputStream(bytes)
+    val inputStream = new ByteArrayInputStream(bytes, 0, readableBytes)
     flush(inputStream) {
       ossMultipartUploader.putPart(inputStream, partNumber, finalFlush)
       source.incCounter(WorkerSource.OSS_FLUSH_COUNT)
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTaskSuite.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTaskSuite.scala
new file mode 100644
index 000000000..da7e2b456
--- /dev/null
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/worker/storage/FlushTaskSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage
+
+import java.io.ByteArrayInputStream
+
+import io.netty.buffer.{ByteBufAllocator, CompositeByteBuf, 
UnpooledByteBufAllocator}
+import org.apache.commons.io.IOUtils
+import org.mockito.ArgumentCaptor
+import org.mockito.ArgumentMatchersSugar.eqTo
+import org.mockito.MockitoSugar.{verify, _}
+import org.scalatest.prop.TableDrivenPropertyChecks.forAll
+import org.scalatest.prop.Tables.Table
+
+import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.common.metrics.source.AbstractSource
+import org.apache.celeborn.server.common.service.mpu.MultipartUploadHandler
+import org.apache.celeborn.service.deploy.worker.WorkerSource
+
+class FlushTaskSuite extends CelebornFunSuite {
+
+  private val ALLOCATOR: ByteBufAllocator = UnpooledByteBufAllocator.DEFAULT
+
+  test("OSSFlushTask flush should work with buffers of various sizes") {
+    runTest(
+      (
+          mockBuffer: CompositeByteBuf,
+          mockNotifier: FlushNotifier,
+          keepBuffer: Boolean,
+          mockSource: AbstractSource,
+          mockMultipartUploader: MultipartUploadHandler,
+          partNumber: Int,
+          finalFlush: Boolean) =>
+        new OssFlushTask(
+          mockBuffer,
+          mockNotifier,
+          keepBuffer,
+          mockSource,
+          mockMultipartUploader,
+          partNumber,
+          finalFlush),
+      (mockSource: AbstractSource, expectedLength: Int) => {
+        verify(mockSource).incCounter(WorkerSource.OSS_FLUSH_COUNT)
+        verify(mockSource).incCounter(WorkerSource.OSS_FLUSH_SIZE, 
expectedLength)
+      })
+  }
+
+  test("S3FlushTask flush should work with buffers of various sizes") {
+    runTest(
+      (
+          mockBuffer: CompositeByteBuf,
+          mockNotifier: FlushNotifier,
+          keepBuffer: Boolean,
+          mockSource: AbstractSource,
+          mockMultipartUploader: MultipartUploadHandler,
+          partNumber: Int,
+          finalFlush: Boolean) =>
+        new S3FlushTask(
+          mockBuffer,
+          mockNotifier,
+          keepBuffer,
+          mockSource,
+          mockMultipartUploader,
+          partNumber,
+          finalFlush),
+      (mockSource: AbstractSource, expectedLength: Int) => {
+        verify(mockSource).incCounter(WorkerSource.S3_FLUSH_COUNT)
+        verify(mockSource).incCounter(WorkerSource.S3_FLUSH_SIZE, 
expectedLength)
+      })
+  }
+
+  def runTest(
+      builder: (
+          CompositeByteBuf,
+          FlushNotifier,
+          Boolean,
+          AbstractSource,
+          MultipartUploadHandler,
+          Int,
+          Boolean) => DfsFlushTask,
+      metricsChecker: (AbstractSource, Int) => Unit) = {
+    val bytes = "another test data".getBytes("UTF-8")
+    val len = bytes.length
+
+    // Define the scenarios: (scenario name, size to allocate)
+    val scenarios = Table(
+      ("description", "allocatedSize"),
+      ("provider buffer is the same size as the buffer", len),
+      ("provider buffer is bigger", len + 10),
+      ("provider buffer smaller", len - 5))
+
+    forAll(scenarios) { (description, bufferSize) =>
+      val mockBuffer = spy(ALLOCATOR.compositeBuffer())
+      mockBuffer.writeBytes(bytes)
+      val mockNotifier = mock[FlushNotifier]
+      val mockSource = mock[AbstractSource]
+      val mockMultipartUploader = mock[MultipartUploadHandler]
+      val partNumber = 3
+      val finalFlush = false
+
+      val flushTask = builder(
+        mockBuffer,
+        mockNotifier,
+        false, // keepBuffer
+        mockSource,
+        mockMultipartUploader,
+        partNumber,
+        finalFlush)
+
+      val copyBytesArray = new Array[Byte](bufferSize)
+      flushTask.flush(copyBytesArray)
+
+      // buffer position is not moved
+      assert(mockBuffer.readableBytes() == bytes.length)
+
+      val streamCaptor = ArgumentCaptor.forClass(classOf[ByteArrayInputStream])
+      verify(mockMultipartUploader).putPart(
+        streamCaptor.capture(),
+        eqTo(partNumber),
+        eqTo(finalFlush))
+      metricsChecker(mockSource, bytes.length)
+
+      val capturedStream = streamCaptor.getValue
+      val capturedBytes = IOUtils.toByteArray(capturedStream);
+      assert(capturedBytes sameElements bytes, s"Content mismatch on: 
$description")
+
+      mockBuffer.release()
+    }
+  }
+}

Reply via email to