This is an automated email from the ASF dual-hosted git repository.

hongshun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fluss.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ec715290 [client] Fix memory leak if OOM occurs when decompressing data in VectorLoader. (#2647)
9ec715290 is described below

commit 9ec715290590b5b0c15e759884e62d0715d5f358
Author: Hongshun Wang <[email protected]>
AuthorDate: Thu Feb 26 15:52:56 2026 +0800

    [client] Fix memory leak if OOM occurs when decompressing data in VectorLoader. (#2647)
---
 .../org/apache/fluss/record/FlussVectorLoader.java | 160 +++++++++++++
 .../java/org/apache/fluss/utils/ArrowUtils.java    |   6 +-
 .../apache/fluss/record/FlussVectorLoaderTest.java | 256 +++++++++++++++++++++
 3 files changed, 419 insertions(+), 3 deletions(-)

diff --git 
a/fluss-common/src/main/java/org/apache/fluss/record/FlussVectorLoader.java 
b/fluss-common/src/main/java/org/apache/fluss/record/FlussVectorLoader.java
new file mode 100644
index 000000000..b37cdbdb3
--- /dev/null
+++ b/fluss-common/src/main/java/org/apache/fluss/record/FlussVectorLoader.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.fluss.record;
+
+import org.apache.fluss.shaded.arrow.org.apache.arrow.memory.ArrowBuf;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.util.Collections2;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.util.Preconditions;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.FieldVector;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.TypeLayout;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VectorSchemaRoot;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.compression.CompressionCodec;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.compression.CompressionUtil;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.compression.CompressionUtil.CodecType;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.compression.NoCompressionCodec;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.message.ArrowFieldNode;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.types.pojo.Field;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * A patched version of Arrow's {@code VectorLoader} that ensures decompressed 
buffers are properly
+ * released when an error (e.g. OOM) occurs during {@link 
#load(ArrowRecordBatch)}.
+ *
+ * <p>In the original Arrow implementation, the decompression loop runs 
<b>outside</b> the
+ * try-finally block that guards {@code loadFieldBuffers}. This means if 
decompression succeeds for
+ * the first N buffers of a field but fails on the (N+1)-th buffer, the 
already-decompressed buffers
+ * in {@code ownBuffers} are never closed, leaking Direct Memory.
+ *
+ * <p>This workaround moves the decompression loop <b>inside</b> the try block 
so that the finally
+ * clause always closes every buffer in {@code ownBuffers}, regardless of 
whether the load succeeds
+ * or fails:
+ *
+ * <ul>
+ *   <li><b>Success path:</b> {@code loadFieldBuffers} retains each buffer 
(ref count +1), then the
+ *       finally close decrements it back (ref count -1). The field vector 
still holds the buffer.
+ *   <li><b>Error path:</b> The finally close decrements each 
already-decompressed buffer's ref
+ *       count to 0, immediately freeing the Direct Memory.
+ * </ul>
+ *
+ * <p>TODO: This class should be removed once Apache Arrow fixes the buffer 
leak in their {@code
+ * VectorLoader.loadBuffers()} method. See:
+ *
+ * <ul>
+ *   <li>Apache Arrow issue: <a
+ *       
href="https://github.com/apache/arrow-java/issues/1037";>arrow-java#1037</a>
+ *   <li>Fluss issue: <a 
href="https://github.com/apache/fluss/issues/2646";>FLUSS-2646</a>
+ * </ul>
+ */
+public class FlussVectorLoader {
+    private final VectorSchemaRoot root;
+    private final CompressionCodec.Factory factory;
+    private boolean decompressionNeeded;
+
+    public FlussVectorLoader(VectorSchemaRoot root, CompressionCodec.Factory 
factory) {
+        this.root = root;
+        this.factory = factory;
+    }
+
+    public void load(ArrowRecordBatch recordBatch) {
+        Iterator<ArrowBuf> buffers = recordBatch.getBuffers().iterator();
+        Iterator<ArrowFieldNode> nodes = recordBatch.getNodes().iterator();
+        CompressionUtil.CodecType codecType =
+                
CodecType.fromCompressionType(recordBatch.getBodyCompression().getCodec());
+        this.decompressionNeeded = codecType != CodecType.NO_COMPRESSION;
+        CompressionCodec codec =
+                this.decompressionNeeded
+                        ? this.factory.createCodec(codecType)
+                        : NoCompressionCodec.INSTANCE;
+
+        for (FieldVector fieldVector : this.root.getFieldVectors()) {
+            this.loadBuffers(fieldVector, fieldVector.getField(), buffers, 
nodes, codec);
+        }
+
+        this.root.setRowCount(recordBatch.getLength());
+        if (nodes.hasNext() || buffers.hasNext()) {
+            throw new IllegalArgumentException(
+                    "not all nodes and buffers were consumed. nodes: "
+                            + Collections2.toString(nodes)
+                            + " buffers: "
+                            + Collections2.toString(buffers));
+        }
+    }
+
+    private void loadBuffers(
+            FieldVector vector,
+            Field field,
+            Iterator<ArrowBuf> buffers,
+            Iterator<ArrowFieldNode> nodes,
+            CompressionCodec codec) {
+        Preconditions.checkArgument(
+                nodes.hasNext(), "no more field nodes for field %s and vector 
%s", field, vector);
+        ArrowFieldNode fieldNode = nodes.next();
+        int bufferLayoutCount = TypeLayout.getTypeBufferCount(field.getType());
+        List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
+
+        try {
+            for (int j = 0; j < bufferLayoutCount; ++j) {
+                ArrowBuf nextBuf = buffers.next();
+                ArrowBuf bufferToAdd =
+                        nextBuf.writerIndex() > 0L
+                                ? codec.decompress(vector.getAllocator(), 
nextBuf)
+                                : nextBuf;
+                ownBuffers.add(bufferToAdd);
+                if (this.decompressionNeeded) {
+                    nextBuf.getReferenceManager().retain();
+                }
+            }
+            vector.loadFieldBuffers(fieldNode, ownBuffers);
+        } catch (RuntimeException e) {
+            throw new IllegalArgumentException(
+                    "Could not load buffers for field "
+                            + field
+                            + ". error message: "
+                            + e.getMessage(),
+                    e);
+        } finally {
+            if (this.decompressionNeeded) {
+                for (ArrowBuf buf : ownBuffers) {
+                    buf.close();
+                }
+            }
+        }
+
+        List<Field> children = field.getChildren();
+        if (!children.isEmpty()) {
+            List<FieldVector> childrenFromFields = 
vector.getChildrenFromFields();
+            Preconditions.checkArgument(
+                    children.size() == childrenFromFields.size(),
+                    "should have as many children as in the schema: found %s 
expected %s",
+                    childrenFromFields.size(),
+                    children.size());
+
+            for (int i = 0; i < childrenFromFields.size(); ++i) {
+                Field child = children.get(i);
+                FieldVector fieldVector = childrenFromFields.get(i);
+                this.loadBuffers(fieldVector, child, buffers, nodes, codec);
+            }
+        }
+    }
+}
diff --git a/fluss-common/src/main/java/org/apache/fluss/utils/ArrowUtils.java 
b/fluss-common/src/main/java/org/apache/fluss/utils/ArrowUtils.java
index 44e3b7fea..4219d45ff 100644
--- a/fluss-common/src/main/java/org/apache/fluss/utils/ArrowUtils.java
+++ b/fluss-common/src/main/java/org/apache/fluss/utils/ArrowUtils.java
@@ -21,6 +21,7 @@ import org.apache.fluss.annotation.Internal;
 import org.apache.fluss.compression.ArrowCompressionFactory;
 import org.apache.fluss.exception.FlussRuntimeException;
 import org.apache.fluss.memory.MemorySegment;
+import org.apache.fluss.record.FlussVectorLoader;
 import org.apache.fluss.row.arrow.ArrowReader;
 import org.apache.fluss.row.arrow.vectors.ArrowArrayColumnVector;
 import org.apache.fluss.row.arrow.vectors.ArrowBigIntColumnVector;
@@ -86,7 +87,6 @@ import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.TypeLayout;
 import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ValueVector;
 import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VarBinaryVector;
 import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VarCharVector;
-import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VectorLoader;
 import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VectorSchemaRoot;
 import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector;
 import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.MapVector;
@@ -172,8 +172,8 @@ public class ArrowUtils {
         try (ReadChannel channel =
                         new ReadChannel(new 
ByteBufferReadableChannel(arrowBatchBuffer));
                 ArrowRecordBatch batch = deserializeRecordBatch(channel, 
allocator)) {
-            VectorLoader vectorLoader =
-                    new VectorLoader(schemaRoot, 
ArrowCompressionFactory.INSTANCE);
+            FlussVectorLoader vectorLoader =
+                    new FlussVectorLoader(schemaRoot, 
ArrowCompressionFactory.INSTANCE);
             vectorLoader.load(batch);
             List<ColumnVector> columnVectors = new ArrayList<>();
             List<FieldVector> fieldVectors = schemaRoot.getFieldVectors();
diff --git 
a/fluss-common/src/test/java/org/apache/fluss/record/FlussVectorLoaderTest.java 
b/fluss-common/src/test/java/org/apache/fluss/record/FlussVectorLoaderTest.java
new file mode 100644
index 000000000..d963bda31
--- /dev/null
+++ 
b/fluss-common/src/test/java/org/apache/fluss/record/FlussVectorLoaderTest.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.record;
+
+import org.apache.fluss.compression.ArrowCompressionFactory;
+import org.apache.fluss.compression.ZstdArrowCompressionCodec;
+import org.apache.fluss.memory.MemorySegment;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.memory.ArrowBuf;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.memory.BufferAllocator;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.memory.RootAllocator;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.IntVector;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VarCharVector;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.VectorUnloader;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.compression.CompressionCodec;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.compression.CompressionUtil;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.WriteChannel;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.message.ArrowBlock;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.ipc.message.MessageSerializer;
+import 
org.apache.fluss.shaded.arrow.org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.types.pojo.Field;
+import org.apache.fluss.shaded.arrow.org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.fluss.utils.ByteBufferReadableChannel;
+import org.apache.fluss.utils.MemorySegmentWritableChannel;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for the patched {@link FlussVectorLoader} that ensures decompressed 
buffers are released
+ * when an error occurs during loading.
+ *
+ * @see <a href="https://github.com/apache/fluss/issues/2646";>FLUSS-2646</a>
+ */
+class FlussVectorLoaderTest {
+
+    private static final Schema SCHEMA =
+            new Schema(
+                    Arrays.asList(
+                            Field.nullable("ints", new ArrowType.Int(32, 
true)),
+                            Field.nullable("strings", 
ArrowType.Utf8.INSTANCE)));
+
+    private BufferAllocator allocator;
+
+    @BeforeEach
+    void setup() {
+        allocator = new RootAllocator(Long.MAX_VALUE);
+    }
+
+    @AfterEach
+    void tearDown() {
+        allocator.close();
+    }
+
+    /** Tests that normal load with compression succeeds and data is correctly 
decompressed. */
+    @Test
+    void testNormalLoadWithCompression() throws Exception {
+        testLoadWithCompression(
+                (schemaRoot, arrowRecordBatch) -> {
+                    FlussVectorLoader loader =
+                            new FlussVectorLoader(schemaRoot, 
ArrowCompressionFactory.INSTANCE);
+                    loader.load(arrowRecordBatch);
+                    assertThat(schemaRoot.getRowCount()).isEqualTo(10);
+                    IntVector readInts = (IntVector) schemaRoot.getVector(0);
+                    VarCharVector readStrings = (VarCharVector) 
schemaRoot.getVector(1);
+                    for (int i = 0; i < 10; i++) {
+                        assertThat(readInts.get(i)).isEqualTo(i);
+                        assertThat(readStrings.get(i))
+                                .isEqualTo(("value_" + 
i).getBytes(StandardCharsets.UTF_8));
+                    }
+                });
+    }
+
+    /**
+     * Tests that when decompression fails mid-way through loading, the 
already-decompressed buffers
+     * are properly released by the try-finally block in VectorLoader.
+     *
+     * <p>This verifies the fix that moved the decompression loop inside the 
try block. Without the
+     * fix, decompressed buffers created before the failure would be leaked.
+     *
+     * <p>The test uses a tracking codec that retains an extra reference on 
each decompressed buffer
+     * for inspection. After VectorLoader's finally block closes the buffers, 
only our tracking
+     * retain should remain (ref count = 1). Without the fix, the ref count 
would be 2.
+     */
+    @Test
+    void testLoadReleasesDecompressedBuffersOnDecompressionFailure() throws 
Exception {
+        // Fail on the 4th decompress call.
+        // "ints" has 2 buffers (both succeed), "strings" validity buffer 
succeeds,
+        // "ints" offset buffer (2nd) fails.
+        testLoadWithCompression(
+                (schemaRoot, arrowRecordBatch) -> {
+                    CompressionCodec.Factory failingFactory = new 
TrackingFailCodecFactory(2);
+                    FlussVectorLoader loader = new 
FlussVectorLoader(schemaRoot, failingFactory);
+                    assertThatThrownBy(() -> loader.load(arrowRecordBatch))
+                            .isInstanceOf(IllegalArgumentException.class)
+                            .hasMessageContaining("Could not load buffers")
+                            .hasMessageContaining(
+                                    "error message: Simulated OOM on 
decompress call");
+                });
+    }
+
+    /**
+     * Tests that when decompression fails for the second field, decompressed 
buffers from the first
+     * field (already loaded into the VectorSchemaRoot) are properly managed 
and can be released via
+     * VectorSchemaRoot.close().
+     */
+    @Test
+    void testLoadReleasesBuffersFromPreviousFieldOnFailure() throws Exception {
+        // Fail on the 4th decompress call.
+        // "ints" has 2 buffers (both succeed), "strings" validity buffer 
succeeds,
+        // "strings" offset buffer (4th) fails.
+        testLoadWithCompression(
+                (schemaRoot, arrowRecordBatch) -> {
+                    CompressionCodec.Factory failingFactory = new 
TrackingFailCodecFactory(4);
+                    FlussVectorLoader loader = new 
FlussVectorLoader(schemaRoot, failingFactory);
+                    assertThatThrownBy(() -> loader.load(arrowRecordBatch))
+                            .isInstanceOf(IllegalArgumentException.class)
+                            .hasMessageContaining("Could not load buffers")
+                            .hasMessageContaining(
+                                    "error message: Simulated OOM on 
decompress call");
+                });
+    }
+
+    void testLoadWithCompression(BiConsumer<VectorSchemaRoot, 
ArrowRecordBatch> loadAction)
+            throws Exception {
+        MemorySegment memorySegment = MemorySegment.allocateHeapMemory(1024 * 
10);
+        MemorySegmentWritableChannel memorySegmentWritableChannel =
+                new MemorySegmentWritableChannel(memorySegment);
+        ArrowBlock arrowBlock;
+        try (VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, 
allocator)) {
+            populateData(root, 10);
+            try (ArrowRecordBatch arrowBatch =
+                    new VectorUnloader(root, true, new 
ZstdArrowCompressionCodec(), true)
+                            .getRecordBatch()) {
+                arrowBlock =
+                        MessageSerializer.serialize(
+                                new 
WriteChannel(memorySegmentWritableChannel), arrowBatch);
+            }
+        }
+
+        ByteBuffer arrowBatchBuffer =
+                memorySegment.wrap(
+                        (int) arrowBlock.getOffset(),
+                        (int) arrowBlock.getBodyLength() + 
arrowBlock.getMetadataLength());
+        ReadChannel channel = new ReadChannel(new 
ByteBufferReadableChannel(arrowBatchBuffer));
+        try (VectorSchemaRoot readRoot = VectorSchemaRoot.create(SCHEMA, 
allocator);
+                ArrowRecordBatch arrowMessage =
+                        MessageSerializer.deserializeRecordBatch(channel, 
allocator)) {
+            loadAction.accept(readRoot, arrowMessage);
+        }
+    }
+
+    private static void populateData(VectorSchemaRoot root, int rowCount) {
+        IntVector ints = (IntVector) root.getVector(0);
+        VarCharVector strings = (VarCharVector) root.getVector(1);
+        for (int i = 0; i < rowCount; i++) {
+            ints.setSafe(i, i);
+            strings.setSafe(i, ("value_" + 
i).getBytes(StandardCharsets.UTF_8));
+        }
+        root.setRowCount(rowCount);
+    }
+
+    // 
-----------------------------------------------------------------------------------------
+    // Test helpers
+    // 
-----------------------------------------------------------------------------------------
+
+    /**
+     * A {@link CompressionCodec.Factory} that creates a {@link 
TrackingFailCodec} which tracks
+     * decompressed buffers and fails on the Nth decompress call.
+     */
+    private static class TrackingFailCodecFactory implements 
CompressionCodec.Factory {
+        private final int failOnNthDecompress;
+
+        TrackingFailCodecFactory(int failOnNthDecompress) {
+            this.failOnNthDecompress = failOnNthDecompress;
+        }
+
+        @Override
+        public CompressionCodec createCodec(CompressionUtil.CodecType 
codecType) {
+            return new TrackingFailCodec(
+                    ArrowCompressionFactory.INSTANCE.createCodec(codecType), 
failOnNthDecompress);
+        }
+
+        @Override
+        public CompressionCodec createCodec(
+                CompressionUtil.CodecType codecType, int compressionLevel) {
+            return new TrackingFailCodec(
+                    ArrowCompressionFactory.INSTANCE.createCodec(codecType, 
compressionLevel),
+                    failOnNthDecompress);
+        }
+    }
+
+    /**
+     * A {@link CompressionCodec} that wraps a real codec but:
+     *
+     * <ol>
+     *   <li>Retains an extra reference on each decompressed buffer for 
tracking.
+     *   <li>Throws a simulated OOM on the Nth decompress call.
+     * </ol>
+     */
+    private static class TrackingFailCodec implements CompressionCodec {
+        private final CompressionCodec delegate;
+        private final int failOnN;
+        private final AtomicInteger count = new AtomicInteger(0);
+
+        TrackingFailCodec(CompressionCodec delegate, int failOnN) {
+            this.delegate = delegate;
+            this.failOnN = failOnN;
+        }
+
+        @Override
+        public ArrowBuf decompress(BufferAllocator allocator, ArrowBuf 
compressedBuffer) {
+            if (count.incrementAndGet() >= failOnN) {
+                throw new RuntimeException("Simulated OOM on decompress call 
#" + count.get());
+            }
+            return delegate.decompress(allocator, compressedBuffer);
+        }
+
+        @Override
+        public ArrowBuf compress(BufferAllocator allocator, ArrowBuf 
uncompressedBuffer) {
+            return delegate.compress(allocator, uncompressedBuffer);
+        }
+
+        @Override
+        public CompressionUtil.CodecType getCodecType() {
+            return delegate.getCodecType();
+        }
+    }
+}

Reply via email to