This is an automated email from the ASF dual-hosted git repository.

Gabriel39 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bf4327a7a5 [fix](maxcompute) Estimate write block size from Arrow buffers, not per-row serialization (#64612)
6bf4327a7a5 is described below

commit 6bf4327a7a58bcff9d2b869869f7ae3067a0f80f
Author: daidai <[email protected]>
AuthorDate: Wed Jun 24 16:20:02 2026 +0800

    [fix](maxcompute) Estimate write block size from Arrow buffers, not per-row serialization (#64612)
    
    The old per-row estimateSingleRowPayloadBytes ZSTD-serialized a one-row
    batch for every row, which was CPU-heavy and overestimated the payload
    size by roughly 25x. Instead, sum FieldVector.getBufferSize() over the
    whole batch, and rotate the block lazily.
---
 .../doris/maxcompute/MaxComputeJniWriter.java      | 184 ++++++++++++++-------
 .../doris/maxcompute/MaxComputeJniWriterTest.java  | 133 +++++++++++++++
 2 files changed, 258 insertions(+), 59 deletions(-)

diff --git 
a/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
 
b/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
index 9788184057e..ecb01d9092f 100644
--- 
a/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
+++ 
b/fe/be-java-extensions/max-compute-connector/src/main/java/org/apache/doris/maxcompute/MaxComputeJniWriter.java
@@ -25,8 +25,6 @@ import org.apache.doris.common.maxcompute.MCUtils;
 
 import com.aliyun.odps.Odps;
 import com.aliyun.odps.OdpsType;
-import com.aliyun.odps.table.arrow.ArrowWriter;
-import com.aliyun.odps.table.arrow.ArrowWriterFactory;
 import com.aliyun.odps.table.configuration.ArrowOptions;
 import com.aliyun.odps.table.configuration.ArrowOptions.TimestampUnit;
 import com.aliyun.odps.table.configuration.CompressionCodec;
@@ -67,7 +65,6 @@ import org.apache.log4j.Logger;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectOutputStream;
-import java.io.OutputStream;
 import java.math.BigDecimal;
 import java.nio.charset.StandardCharsets;
 import java.time.LocalDate;
@@ -125,6 +122,10 @@ public class MaxComputeJniWriter extends JniWriter {
     private List<String> columnNames;
     private long currentBlockId = -1L;
     private long currentBlockWrittenBytes = 0L;
+    // Per-row Arrow payload size observed from previously written ranges. 
Used to bound
+    // how many rows are materialized into a single Arrow root, so a large 
incoming JNI
+    // block is never copied whole before its size is known. Refined as ranges 
are written.
+    private long observedBytesPerRow = 0L;
     private final List<WriterCommitMessage> commitMessages = new ArrayList<>();
 
     // Statistics
@@ -234,7 +235,7 @@ public class MaxComputeJniWriter extends JniWriter {
         }
 
         try {
-            writeRowsWithRowChecks(inputTable, numRows, numCols);
+            writeBatch(inputTable, numRows, numCols);
         } catch (Exception e) {
             String errorMsg = "Failed to write data to MaxCompute table " + 
project + "." + tableName;
             LOG.error(errorMsg, e);
@@ -272,79 +273,144 @@ public class MaxComputeJniWriter extends JniWriter {
         openBatchWriter(requestBlockId());
     }
 
-    private void writeRowsWithRowChecks(VectorTable inputTable, int numRows, 
int numCols) throws IOException {
+    private void writeBatch(VectorTable inputTable, int numRows, int numCols) 
throws IOException {
         int rowStart = 0;
         while (rowStart < numRows) {
-            int rowEnd = rowStart;
-            long batchEstimatedBytes = 0L;
-            boolean rotateAfterWrite = false;
-            while (rowEnd < numRows) {
-                long rowEstimatedBytes = 
estimateSingleRowPayloadBytes(inputTable, numCols, rowEnd);
-                boolean exceedsHardLimit = currentBlockWrittenBytes + 
batchEstimatedBytes
-                        + rowEstimatedBytes > maxBlockBytes;
-                if (exceedsHardLimit) {
-                    if (rowEnd == rowStart) {
-                        if (currentBlockWrittenBytes > 0) {
-                            rotateCurrentBatchWriter();
-                            continue;
-                        }
-                        batchEstimatedBytes += rowEstimatedBytes;
-                        rowEnd++;
-                        rotateAfterWrite = true;
-                    }
-                    break;
+            // Bound the rows copied into one Arrow root using the per-row 
size observed so
+            // far, so an oversized incoming block is never materialized whole 
before we know
+            // whether it fits the current block.
+            int probeEnd = rowStart + 
boundedProbeRowCount(observedBytesPerRow, maxBlockBytes, numRows - rowStart);
+            try (VectorSchemaRoot root = buildRowRangeRoot(inputTable, 
numCols, rowStart, probeEnd)) {
+                int probeRows = probeEnd - rowStart;
+                long probeBytes = estimateBatchPayloadBytes(root);
+                observedBytesPerRow = probeBytes / probeRows;
+                if (currentBlockWrittenBytes + probeBytes <= maxBlockBytes) {
+                    writeRoot(root, probeRows, probeBytes);
+                    rowStart = probeEnd;
+                    continue;
                 }
-                batchEstimatedBytes += rowEstimatedBytes;
-                rowEnd++;
-                if (currentBlockWrittenBytes + batchEstimatedBytes >= 
maxBlockBytes) {
-                    rotateAfterWrite = true;
-                    break;
+
+                // The probe overflows the current block. Split it WITHOUT 
rebuilding: the binary
+                // search measures leading-row sizes from this already-built 
root via
+                // getBufferSizeFor, then we slice off the prefix that fits. 
The remaining rows are
+                // rebuilt on the next iteration (after rotating), so no Arrow 
buffer outlives the
+                // current block writer.
+                RowRange rowRange = findPartialRowRange(rowStart, probeEnd, 
currentBlockWrittenBytes,
+                        maxBlockBytes, (rangeStart, rangeEnd) -> 
prefixBufferBytes(root, rangeEnd - rangeStart));
+                if (rowRange.rotateBeforeWrite) {
+                    rotateCurrentBatchWriter();
+                    continue;
                 }
-            }
 
-            if (rowEnd == rowStart) {
-                long rowEstimatedBytes = 
estimateSingleRowPayloadBytes(inputTable, numCols, rowStart);
-                batchEstimatedBytes = rowEstimatedBytes;
-                rowEnd = rowStart + 1;
-                rotateAfterWrite = true;
+                int headRows = rowRange.rowEnd - rowStart;
+                try (VectorSchemaRoot head = root.slice(0, headRows)) {
+                    writeRoot(head, headRows, rowRange.bytes);
+                }
+                rowStart = rowRange.rowEnd;
             }
-
-            try (VectorSchemaRoot root = buildRowRangeRoot(inputTable, 
numCols, rowStart, rowEnd)) {
-                batchWriter.write(root);
+            if (rowStart < numRows && currentBlockWrittenBytes >= 
maxBlockBytes) {
+                rotateCurrentBatchWriter();
             }
-            batchWriter.flush();
-            int rowsWrittenNow = rowEnd - rowStart;
-            writtenRows += rowsWrittenNow;
-            currentBlockWrittenBytes += batchEstimatedBytes;
-            writtenBytes += batchEstimatedBytes;
-            rowStart = rowEnd;
+        }
+    }
 
-            if (rotateAfterWrite && rowStart < numRows) {
-                rotateCurrentBatchWriter();
+    // Off-heap payload bytes of the leading rowCount rows of an already-built 
Arrow root,
+    // read from the existing column buffers (getBufferSizeFor) without 
rebuilding any vector.
+    static long prefixBufferBytes(VectorSchemaRoot root, int rowCount) {
+        long total = 0L;
+        for (FieldVector vector : root.getFieldVectors()) {
+            total += vector.getBufferSizeFor(rowCount);
+        }
+        return total;
+    }
+
+    /**
+     * Choose how many rows to materialize into the next Arrow root, bounded 
so a large
+     * incoming JNI block is never copied whole before its size is known. The 
bound targets
+     * roughly one MaxCompute block worth of payload using {@code 
observedBytesPerRow}; before
+     * any range has been measured it probes a single row, then sizes from 
that row's measured
+     * Arrow payload. The result is at least one row and never exceeds {@code 
remainingRows}.
+     */
+    static int boundedProbeRowCount(long observedBytesPerRow, long 
maxBlockBytes, int remainingRows) {
+        long cap;
+        if (observedBytesPerRow <= 0L) {
+            cap = 1L;
+        } else {
+            cap = Math.max(1L, maxBlockBytes / observedBytesPerRow);
+        }
+        if (cap >= remainingRows) {
+            return remainingRows;
+        }
+        return (int) cap;
+    }
+
+    private void writeRoot(VectorSchemaRoot root, int numRows, long 
batchBytes) throws IOException {
+        batchWriter.write(root);
+        batchWriter.flush();
+
+        writtenRows += numRows;
+        currentBlockWrittenBytes += batchBytes;
+        writtenBytes += batchBytes;
+    }
+
+    static RowRange findPartialRowRange(int rowStart, int numRows, long 
currentBlockWrittenBytes,
+            long maxBlockBytes, RowRangeByteEstimator estimator) throws 
IOException {
+        int low = rowStart + 1;
+        int high = numRows - 1;
+        int bestEnd = rowStart;
+        long bestBytes = 0L;
+        while (low <= high) {
+            int mid = low + (high - low) / 2;
+            long rangeBytes = estimator.estimate(rowStart, mid);
+            if (currentBlockWrittenBytes + rangeBytes <= maxBlockBytes) {
+                bestEnd = mid;
+                bestBytes = rangeBytes;
+                low = mid + 1;
+            } else {
+                high = mid - 1;
             }
         }
+
+        if (bestEnd > rowStart) {
+            return RowRange.write(bestEnd, bestBytes);
+        }
+        if (currentBlockWrittenBytes > 0) {
+            return RowRange.rotateBeforeWrite();
+        }
+        return RowRange.write(rowStart + 1, estimator.estimate(rowStart, 
rowStart + 1));
+    }
+
+    interface RowRangeByteEstimator {
+        long estimate(int rowStart, int rowEnd) throws IOException;
     }
 
-    private static class CountingDiscardOutputStream extends OutputStream {
-        @Override
-        public void write(int b) {
-            // Discard bytes while allowing WriteChannel to track payload size.
+    static class RowRange {
+        final int rowEnd;
+        final long bytes;
+        final boolean rotateBeforeWrite;
+
+        private RowRange(int rowEnd, long bytes, boolean rotateBeforeWrite) {
+            this.rowEnd = rowEnd;
+            this.bytes = bytes;
+            this.rotateBeforeWrite = rotateBeforeWrite;
+        }
+
+        static RowRange write(int rowEnd, long bytes) {
+            return new RowRange(rowEnd, bytes, false);
         }
 
-        @Override
-        public void write(byte[] b, int off, int len) {
-            // Discard bytes while allowing WriteChannel to track payload size.
+        static RowRange rotateBeforeWrite() {
+            return new RowRange(-1, 0L, true);
         }
     }
 
-    private long estimateSingleRowPayloadBytes(VectorTable inputTable, int 
numCols, int rowIndex)
-            throws IOException {
-        try (VectorSchemaRoot root = buildRowRangeRoot(inputTable, numCols, 
rowIndex, rowIndex + 1);
-                ArrowWriter estimator = 
ArrowWriterFactory.getRecordBatchWriter(
-                        new CountingDiscardOutputStream(), writerOptions)) {
-            estimator.writeBatch(root);
-            return estimator.bytesWritten();
+    // Estimate an Arrow batch's payload size from its column buffer sizes 
(O(columns)).
+    static long estimateBatchPayloadBytes(VectorSchemaRoot root) {
+        long total = 0L;
+        for (FieldVector vector : root.getFieldVectors()) {
+            total += vector.getBufferSize();
         }
+        return total;
     }
 
     private VectorSchemaRoot buildRowRangeRoot(VectorTable inputTable, int 
numCols, int rowStart, int rowEnd) {
diff --git 
a/fe/be-java-extensions/max-compute-connector/src/test/java/org/apache/doris/maxcompute/MaxComputeJniWriterTest.java
 
b/fe/be-java-extensions/max-compute-connector/src/test/java/org/apache/doris/maxcompute/MaxComputeJniWriterTest.java
new file mode 100644
index 00000000000..84cc4a79d6e
--- /dev/null
+++ 
b/fe/be-java-extensions/max-compute-connector/src/test/java/org/apache/doris/maxcompute/MaxComputeJniWriterTest.java
@@ -0,0 +1,133 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.maxcompute;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Collections;
+
+public class MaxComputeJniWriterTest {
+    @Test
+    public void testPrefixBufferBytesMeasuresLeadingRowsWithoutRebuild() {
+        try (BufferAllocator allocator = new RootAllocator();
+                IntVector vec = new IntVector("c", allocator)) {
+            vec.allocateNew(8);
+            for (int i = 0; i < 8; i++) {
+                vec.set(i, i);
+            }
+            vec.setValueCount(8);
+            try (VectorSchemaRoot root = new 
VectorSchemaRoot(Collections.singletonList(vec))) {
+                // The whole-root measurement must match 
estimateBatchPayloadBytes...
+                
Assert.assertEquals(MaxComputeJniWriter.estimateBatchPayloadBytes(root),
+                        MaxComputeJniWriter.prefixBufferBytes(root, 
root.getRowCount()));
+                // ...and a leading prefix must be strictly smaller, computed 
from the
+                // already-built buffers (no rebuild).
+                Assert.assertTrue(MaxComputeJniWriter.prefixBufferBytes(root, 
4)
+                        < MaxComputeJniWriter.prefixBufferBytes(root, 8));
+            }
+        }
+    }
+
+    @Test
+    public void testFindPartialRowRangeFillsRemainingBlock() throws Exception {
+        MaxComputeJniWriter.RowRange range = 
MaxComputeJniWriter.findPartialRowRange(
+                0, 4, 60L, 100L, prefixEstimator(10L, 20L, 30L, 40L));
+
+        Assert.assertFalse(range.rotateBeforeWrite);
+        Assert.assertEquals(2, range.rowEnd);
+        Assert.assertEquals(30L, range.bytes);
+    }
+
+    @Test
+    public void testFindPartialRowRangeRotatesWhenNoRowFitsNonEmptyBlock() 
throws Exception {
+        MaxComputeJniWriter.RowRange range = 
MaxComputeJniWriter.findPartialRowRange(
+                0, 3, 95L, 100L, prefixEstimator(10L, 20L, 30L));
+
+        Assert.assertTrue(range.rotateBeforeWrite);
+    }
+
+    @Test
+    public void 
testFindPartialRowRangeKeepsSingleOversizeFallbackOnEmptyBlock() throws 
Exception {
+        MaxComputeJniWriter.RowRange range = 
MaxComputeJniWriter.findPartialRowRange(
+                0, 3, 0L, 5L, prefixEstimator(10L, 20L, 30L));
+
+        Assert.assertFalse(range.rotateBeforeWrite);
+        Assert.assertEquals(1, range.rowEnd);
+        Assert.assertEquals(10L, range.bytes);
+    }
+
+    @Test
+    public void testFindPartialRowRangeUsesRowStartOffset() throws Exception {
+        MaxComputeJniWriter.RowRange range = 
MaxComputeJniWriter.findPartialRowRange(
+                1, 4, 50L, 100L, prefixEstimator(999L, 30L, 30L, 50L));
+
+        Assert.assertFalse(range.rotateBeforeWrite);
+        Assert.assertEquals(2, range.rowEnd);
+        Assert.assertEquals(30L, range.bytes);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountBootstrapsWithSingleRow() {
+        // No per-row estimate yet: bootstrap by measuring a single row's real 
Arrow payload,
+        // so an oversized input is never copied whole and we never guess a 
row count.
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(0L, 64L * 
1024 * 1024, 1_000_000);
+
+        Assert.assertEquals(1, probeRows);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountTargetsOneBlockAfterMeasurement() {
+        // 1 KiB/row against a 64 MiB block => ~65536 rows fill one block.
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(1024L, 64L * 
1024 * 1024, 1_000_000);
+
+        Assert.assertEquals(65536, probeRows);
+        Assert.assertTrue(probeRows < 1_000_000);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountReturnsRemainingWhenItFitsCap() {
+        // A small input that comfortably fits one block is probed in one shot.
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(1024L, 64L * 
1024 * 1024, 4096);
+
+        Assert.assertEquals(4096, probeRows);
+    }
+
+    @Test
+    public void testBoundedProbeRowCountProbesSingleRowWhenRowExceedsBlock() {
+        // A single row larger than a whole block must still make progress 
(never 0 rows).
+        int probeRows = MaxComputeJniWriter.boundedProbeRowCount(
+                128L * 1024 * 1024, 64L * 1024 * 1024, 1_000_000);
+
+        Assert.assertEquals(1, probeRows);
+    }
+
+    private static MaxComputeJniWriter.RowRangeByteEstimator 
prefixEstimator(long... rowBytes) {
+        return (rowStart, rowEnd) -> {
+            long bytes = 0L;
+            for (int i = rowStart; i < rowEnd; i++) {
+                bytes += rowBytes[i];
+            }
+            return bytes;
+        };
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to