This is an automated email from the ASF dual-hosted git repository.

CurtHagenlocher pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-dotnet.git


The following commit(s) were added to refs/heads/main by this push:
     new 3434344  perf: speed up MemoryStream IPC stream reads (#340)
3434344 is described below

commit 343434468c749c47ed0f3d20823bb29717a3c33c
Author: InCerryGit <[email protected]>
AuthorDate: Wed Apr 29 00:44:58 2026 +0800

    perf: speed up MemoryStream IPC stream reads (#340)
    
    ## Summary
    
    This improves `ArrowStreamReader` when reading from `MemoryStream`
    instances that expose their underlying buffer. The reader now uses the
    exposed buffer for IPC message/schema metadata reads while preserving
    the existing reader-owned body-buffer boundary.
    
    The change is intentionally scoped to MemoryStream-backed IPC stream
    reads:
    
    - public/exposed `MemoryStream` can use the fast path
    - non-public `MemoryStream` and partial-read streams continue through
    the fallback stream-read path
    - record batch body data is still copied into allocator-owned memory
    before array construction
    - `ArrowMemoryReader` exact-length continuation-token handling is
    corrected for complete in-memory buffers
    
    ## Benchmark
    
    BenchmarkDotNet ShortRun, `ArrowReaderBenchmark`:
    
    | Scenario | Before | After |
    |---|---:|---:|
    | `ArrowReaderWithMemoryStream_ManagedMemory`, 100000 rows / 1 column | 21629.3 us | 7707.6 us |
    | `ArrowReaderWithMemoryStream_ManagedMemory`, 100000 rows / 5 columns | 91112.3 us | 40137.5 us |
    
    ## Validation
    
    - `dotnet test test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj -c
    Release --filter
    "FullyQualifiedName~Apache.Arrow.Tests.ArrowStreamReaderTests"`
    - `dotnet test
    test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj
    -c Release --filter
    "FullyQualifiedName~Apache.Arrow.Compression.Tests.ArrowStreamReaderTests"`
    - `dotnet build Apache.Arrow.sln -c Release`
---
 .../Ipc/ArrowMemoryReaderImplementation.cs         |   4 +-
 .../Ipc/ArrowMemoryStreamReaderImplementation.cs   | 214 +++++++++++
 src/Apache.Arrow/Ipc/ArrowStreamReader.cs          |  19 +-
 .../Ipc/ArrowStreamReaderImplementation.cs         |  27 +-
 .../ArrowReaderBenchmark.cs                        |  92 ++++-
 .../ArrowStreamReaderTests.cs                      |  41 ++-
 test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs  | 408 ++++++++++++++++++---
 7 files changed, 730 insertions(+), 75 deletions(-)

diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs 
b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index e9cd936..887036c 100644
--- a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -87,7 +87,7 @@ namespace Apache.Arrow.Ipc
                 else if (messageLength == 
MessageSerializer.IpcContinuationToken)
                 {
                     // ARROW-6313, if the first 4 bytes are continuation 
message, read the next 4 for the length
-                    if (_buffer.Length <= _bufferPosition + sizeof(int))
+                    if (_buffer.Length < _bufferPosition + sizeof(int))
                     {
                         throw new InvalidDataException("Corrupted IPC message. 
Received a continuation token at the end of the message.");
                     }
@@ -136,7 +136,7 @@ namespace Apache.Arrow.Ipc
             if (schemaMessageLength == MessageSerializer.IpcContinuationToken)
             {
                 // ARROW-6313, if the first 4 bytes are continuation message, 
read the next 4 for the length
-                if (_buffer.Length <= _bufferPosition + sizeof(int))
+                if (_buffer.Length < _bufferPosition + sizeof(int))
                 {
                     throw new InvalidDataException("Corrupted IPC message. 
Received a continuation token at the end of the message.");
                 }
diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs 
b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
new file mode 100644
index 0000000..01d3b0b
--- /dev/null
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
@@ -0,0 +1,214 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Memory;
+
+namespace Apache.Arrow.Ipc
+{
+    /// <summary>
+    /// Reads Arrow IPC streams from a <see cref="MemoryStream"/> whose 
backing buffer is publicly visible.
+    /// </summary>
+    /// <remarks>
+    /// Message metadata can be read directly from the exposed stream buffer, 
but record batch bodies are
+    /// still copied into allocator-owned buffers to preserve <see 
cref="ArrowStreamReader"/> ownership semantics.
+    /// </remarks>
+    internal sealed class ArrowMemoryStreamReaderImplementation : 
ArrowStreamReaderImplementation
+    {
+        private readonly MemoryStream _stream;
+        private readonly Memory<byte> _streamMemory;
+
+        public ArrowMemoryStreamReaderImplementation(
+            MemoryStream stream,
+            MemoryAllocator allocator,
+            ICompressionCodecFactory compressionCodecFactory,
+            bool leaveOpen,
+            ExtensionTypeRegistry extensionRegistry)
+            : base(stream, allocator, compressionCodecFactory, leaveOpen, 
extensionRegistry)
+        {
+            _stream = stream;
+
+            if (!stream.TryGetBuffer(out ArraySegment<byte> streamBuffer))
+            {
+                throw new InvalidOperationException("Expected MemoryStream to 
expose its backing buffer.");
+            }
+
+            _streamMemory = streamBuffer.Array.AsMemory(streamBuffer.Offset, 
streamBuffer.Count);
+        }
+
+        public override ValueTask<RecordBatch> 
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+
+            try
+            {
+                return new ValueTask<RecordBatch>(ReadNextRecordBatch());
+            }
+            catch (Exception ex)
+            {
+                return new 
ValueTask<RecordBatch>(Task.FromException<RecordBatch>(ex));
+            }
+        }
+
+        public override RecordBatch ReadNextRecordBatch()
+        {
+            ReadSchema();
+
+            ReadResult result = default;
+            do
+            {
+                result = ReadMessageFromMemory();
+            } while (result.Batch == null && result.MessageLength > 0);
+
+            return result.Batch;
+        }
+
+        public override ValueTask<Schema> ReadSchemaAsync(CancellationToken 
cancellationToken = default)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+
+            if (HasReadSchema)
+            {
+                return new ValueTask<Schema>(_schema);
+            }
+
+            try
+            {
+                ReadSchema();
+                return new ValueTask<Schema>(_schema);
+            }
+            catch (Exception ex)
+            {
+                return new ValueTask<Schema>(Task.FromException<Schema>(ex));
+            }
+        }
+
+        public override void ReadSchema()
+        {
+            if (HasReadSchema)
+            {
+                return;
+            }
+
+            int schemaMessageLength = 
ReadMessageLengthFromMemory(throwOnFullRead: true, returnOnEmptyStream: true);
+            if (schemaMessageLength == 0)
+            {
+                return;
+            }
+
+            Memory<byte> schemaBuffer = ReadMemory(schemaMessageLength);
+            _schema = 
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(CreateByteBuffer(schemaBuffer)),
 ref _dictionaryMemo, _extensionRegistry);
+        }
+
+        private ReadResult ReadMessageFromMemory()
+        {
+            int messageLength = ReadMessageLengthFromMemory(throwOnFullRead: 
false, returnOnEmptyStream: false);
+            if (messageLength == 0)
+            {
+                return default;
+            }
+
+            Memory<byte> messageBuffer = ReadMemory(messageLength);
+            Flatbuf.Message message = 
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuffer));
+
+            if (message.BodyLength > int.MaxValue)
+            {
+                throw new OverflowException(
+                    $"Arrow IPC message body length ({message.BodyLength}) is 
larger than " +
+                    $"the maximum supported message size ({int.MaxValue})");
+            }
+
+            int bodyLength = (int)message.BodyLength;
+            Memory<byte> sourceBodyBuffer = ReadMemory(bodyLength);
+            IMemoryOwner<byte> bodyBufferOwner = 
AllocateMessageBodyBuffer(bodyLength);
+            Memory<byte> bodyBuffer = bodyBufferOwner.Memory.Slice(0, 
bodyLength);
+            sourceBodyBuffer.CopyTo(bodyBuffer);
+            Google.FlatBuffers.ByteBuffer bodybb = 
CreateByteBuffer(bodyBuffer);
+
+            // Keep stream-reader ownership semantics: batches outlive the 
source MemoryStream buffer.
+            return new ReadResult(messageLength, 
CreateArrowObjectFromMessage(message, bodybb, bodyBufferOwner));
+        }
+
+        private int ReadMessageLengthFromMemory(bool throwOnFullRead, bool 
returnOnEmptyStream)
+        {
+            if (_stream.Position == _stream.Length && returnOnEmptyStream)
+            {
+                return 0;
+            }
+
+            if (!TryReadInt32(throwOnFullRead, out int messageLength))
+            {
+                return 0;
+            }
+
+            if (messageLength == MessageSerializer.IpcContinuationToken &&
+                !TryReadInt32(throwOnFullRead, out messageLength))
+            {
+                return 0;
+            }
+
+            return messageLength;
+        }
+
+        private bool TryReadInt32(bool throwOnFullRead, out int value)
+        {
+            value = 0;
+
+            if (!TryReadMemory(sizeof(int), throwOnFullRead, out Memory<byte> 
buffer))
+            {
+                return false;
+            }
+
+            value = BitUtility.ReadInt32(buffer);
+            return true;
+        }
+
+        private bool TryReadMemory(int length, bool throwOnFullRead, out 
Memory<byte> buffer)
+        {
+            buffer = default;
+
+            long remainingLength = _stream.Length - _stream.Position;
+            if (remainingLength < length)
+            {
+                if (throwOnFullRead)
+                {
+                    throw new InvalidOperationException("Unexpectedly reached 
the end of the stream before a full buffer was read.");
+                }
+
+                _stream.Position = _stream.Length;
+                return false;
+            }
+
+            buffer = ReadMemory(length);
+            return true;
+        }
+
+        private Memory<byte> ReadMemory(int length)
+        {
+            if (length == 0)
+            {
+                return Memory<byte>.Empty;
+            }
+
+            Memory<byte> buffer = 
_streamMemory.Slice(checked((int)_stream.Position), length);
+            _stream.Position += length;
+            return buffer;
+        }
+    }
+}
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs 
b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
index e5dade2..6100eea 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
@@ -68,7 +68,7 @@ namespace Apache.Arrow.Ipc
             if (stream == null)
                 throw new ArgumentNullException(nameof(stream));
 
-            _implementation = new ArrowStreamReaderImplementation(stream, 
allocator, compressionCodecFactory, leaveOpen);
+            _implementation = CreateImplementation(stream, allocator, 
compressionCodecFactory, leaveOpen, extensionRegistry: null);
         }
 
         public ArrowStreamReader(ArrowContext context, Stream stream, bool 
leaveOpen = false)
@@ -78,7 +78,7 @@ namespace Apache.Arrow.Ipc
             if (context == null)
                 throw new ArgumentNullException(nameof(context));
 
-            _implementation = new ArrowStreamReaderImplementation(stream, 
context.Allocator, context.CompressionCodecFactory, leaveOpen, 
context.ExtensionRegistry);
+            _implementation = CreateImplementation(stream, context.Allocator, 
context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
         }
 
         public ArrowStreamReader(ReadOnlyMemory<byte> buffer)
@@ -104,6 +104,21 @@ namespace Apache.Arrow.Ipc
             _implementation = implementation;
         }
 
+        private static ArrowReaderImplementation CreateImplementation(
+            Stream stream,
+            MemoryAllocator allocator,
+            ICompressionCodecFactory compressionCodecFactory,
+            bool leaveOpen,
+            ExtensionTypeRegistry extensionRegistry)
+        {
+            if (stream is MemoryStream memoryStream && 
memoryStream.TryGetBuffer(out _))
+            {
+                return new ArrowMemoryStreamReaderImplementation(memoryStream, 
allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
+            }
+
+            return new ArrowStreamReaderImplementation(stream, allocator, 
compressionCodecFactory, leaveOpen, extensionRegistry);
+        }
+
         public void Dispose()
         {
             Dispose(true);
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs 
b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 9d0cbe3..179b998 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -47,11 +47,10 @@ namespace Apache.Arrow.Ipc
             }
         }
 
-        public override async ValueTask<RecordBatch> 
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+        public override ValueTask<RecordBatch> 
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
         {
-            // TODO: Loop until a record batch is read.
             cancellationToken.ThrowIfCancellationRequested();
-            return await 
ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
+            return ReadRecordBatchAsync(cancellationToken);
         }
 
         public override RecordBatch ReadNextRecordBatch()
@@ -61,7 +60,7 @@ namespace Apache.Arrow.Ipc
 
         protected async ValueTask<RecordBatch> 
ReadRecordBatchAsync(CancellationToken cancellationToken = default)
         {
-            await ReadSchemaAsync().ConfigureAwait(false);
+            await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
 
             ReadResult result = default;
             do
@@ -94,7 +93,7 @@ namespace Apache.Arrow.Ipc
 
                 int bodyLength = checked((int)message.BodyLength);
 
-                IMemoryOwner<byte> bodyBuffOwner = 
_allocator.Allocate(bodyLength);
+                IMemoryOwner<byte> bodyBuffOwner = 
AllocateMessageBodyBuffer(bodyLength);
                 Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, 
bodyLength);
                 bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff, 
cancellationToken)
                     .ConfigureAwait(false);
@@ -145,7 +144,7 @@ namespace Apache.Arrow.Ipc
                 }
                 int bodyLength = (int)message.BodyLength;
 
-                IMemoryOwner<byte> bodyBuffOwner = 
_allocator.Allocate(bodyLength);
+                IMemoryOwner<byte> bodyBuffOwner = 
AllocateMessageBodyBuffer(bodyLength);
                 Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0, 
bodyLength);
                 bytesRead = BaseStream.ReadFullBuffer(bodyBuff);
                 EnsureFullRead(bodyBuff, bytesRead);
@@ -157,13 +156,25 @@ namespace Apache.Arrow.Ipc
             return new ReadResult(messageLength, result);
         }
 
-        public override async ValueTask<Schema> 
ReadSchemaAsync(CancellationToken cancellationToken = default)
+        protected IMemoryOwner<byte> AllocateMessageBodyBuffer(int bodyLength)
         {
+            return _allocator.Allocate(bodyLength);
+        }
+
+        public override ValueTask<Schema> ReadSchemaAsync(CancellationToken 
cancellationToken = default)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+
             if (HasReadSchema)
             {
-                return _schema;
+                return new ValueTask<Schema>(_schema);
             }
 
+            return ReadSchemaAsyncCore(cancellationToken);
+        }
+
+        private async ValueTask<Schema> ReadSchemaAsyncCore(CancellationToken 
cancellationToken)
+        {
             // Figure out length of schema
             int schemaMessageLength = await 
ReadMessageLengthAsync(throwOnFullRead: true, returnOnEmptyStream: true, 
cancellationToken)
                 .ConfigureAwait(false);
diff --git a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs 
b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
index 3760d94..9305adb 100644
--- a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
+++ b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
@@ -14,8 +14,8 @@
 // limitations under the License.
 
 using System;
+using System.Collections.Generic;
 using System.IO;
-using System.Linq;
 using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
 using Apache.Arrow.Memory;
@@ -32,13 +32,19 @@ namespace Apache.Arrow.Benchmarks
         [Params(10_000, 1_000_000)]
         public int Count { get; set; }
 
+        [Params(1, 5)]
+        public int ColumnSetCount { get; set; }
+
         private MemoryStream _memoryStream;
         private static readonly MemoryAllocator s_allocator = new 
TestMemoryAllocator();
 
         [GlobalSetup]
         public async Task GlobalSetup()
         {
-            RecordBatch batch = TestData.CreateSampleRecordBatch(length: 
Count, createDictionaryArray: false);
+            RecordBatch batch = TestData.CreateSampleRecordBatch(
+                length: Count,
+                columnSetCount: ColumnSetCount,
+                excludedTypes: new HashSet<ArrowTypeId> { 
ArrowTypeId.Dictionary, ArrowTypeId.RunEndEncoded });
             _memoryStream = new MemoryStream();
 
             ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream, 
batch.Schema);
@@ -83,6 +89,73 @@ namespace Apache.Arrow.Benchmarks
             return sum;
         }
 
+        [Benchmark]
+        public async Task<double> 
ArrowReaderWithMemoryStream_ExplicitDefaultAllocator()
+        {
+            double sum = 0;
+            var reader = new ArrowStreamReader(_memoryStream, 
MemoryAllocator.Default.Value);
+            RecordBatch recordBatch;
+            while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != 
null)
+            {
+                using (recordBatch)
+                {
+                    sum += SumAllNumbers(recordBatch);
+                }
+            }
+            return sum;
+        }
+
+        [Benchmark]
+        public async Task<double> 
ArrowReaderWithNonPubliclyVisibleMemoryStream()
+        {
+            double sum = 0;
+            using var stream = CreateNonPubliclyVisibleReadStream();
+            using var reader = new ArrowStreamReader(stream);
+            RecordBatch recordBatch;
+            while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != 
null)
+            {
+                using (recordBatch)
+                {
+                    sum += SumAllNumbers(recordBatch);
+                }
+            }
+            return sum;
+        }
+
+        [Benchmark]
+        public async Task<double> 
ArrowReaderWithNonPubliclyVisibleMemoryStream_ManagedMemory()
+        {
+            double sum = 0;
+            using var stream = CreateNonPubliclyVisibleReadStream();
+            using var reader = new ArrowStreamReader(stream, s_allocator);
+            RecordBatch recordBatch;
+            while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != 
null)
+            {
+                using (recordBatch)
+                {
+                    sum += SumAllNumbers(recordBatch);
+                }
+            }
+            return sum;
+        }
+
+        [Benchmark]
+        public async Task<double> 
ArrowReaderWithNonPubliclyVisibleMemoryStream_ExplicitDefaultAllocator()
+        {
+            double sum = 0;
+            using var stream = CreateNonPubliclyVisibleReadStream();
+            using var reader = new ArrowStreamReader(stream, 
MemoryAllocator.Default.Value);
+            RecordBatch recordBatch;
+            while ((recordBatch = await reader.ReadNextRecordBatchAsync()) != 
null)
+            {
+                using (recordBatch)
+                {
+                    sum += SumAllNumbers(recordBatch);
+                }
+            }
+            return sum;
+        }
+
         [Benchmark]
         public async Task<double> ArrowReaderWithMemory()
         {
@@ -99,14 +172,25 @@ namespace Apache.Arrow.Benchmarks
             return sum;
         }
 
+        private MemoryStream CreateNonPubliclyVisibleReadStream()
+        {
+            return new MemoryStream(
+                _memoryStream.GetBuffer(),
+                index: 0,
+                count: checked((int)_memoryStream.Length),
+                writable: false,
+                publiclyVisible: false);
+        }
+
         private static double SumAllNumbers(RecordBatch recordBatch)
         {
             double sum = 0;
 
             for (int k = 0; k < recordBatch.ColumnCount; k++)
             {
-                var array = recordBatch.Arrays.ElementAt(k);
-                switch (recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId)
+                var array = recordBatch.Column(k);
+                ArrowTypeId typeId = 
recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId;
+                switch (typeId)
                 {
                     case ArrowTypeId.Int64:
                         Int64Array int64Array = (Int64Array)array;
diff --git a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs 
b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
index 9c2bf75..99e7442 100644
--- a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
@@ -14,8 +14,10 @@
 // limitations under the License.
 
 using System;
+using System.IO;
 using System.Reflection;
 using Apache.Arrow.Ipc;
+using Apache.Arrow.Memory;
 using Apache.Arrow.Tests;
 using Xunit;
 
@@ -47,13 +49,34 @@ namespace Apache.Arrow.Compression.Tests
             using var stream = 
assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}");
             Assert.NotNull(stream);
             var buffer = new byte[stream.Length];
-            stream.ReadFullBuffer(buffer);
+            ReadExactly(stream, buffer);
             var codecFactory = new Compression.CompressionCodecFactory();
             using var reader = new ArrowStreamReader(buffer, codecFactory);
 
             VerifyCompressedIpcFileBatch(reader.ReadNextRecordBatch());
         }
 
+        [Theory]
+        [InlineData("ipc_lz4_compression.arrow_stream")]
+        [InlineData("ipc_zstd_compression.arrow_stream")]
+        public void 
CanReadCompressedIpcStreamFromMemoryBuffer_UsesDefaultAllocator(string fileName)
+        {
+            var assembly = Assembly.GetExecutingAssembly();
+            using var stream = 
assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}");
+            Assert.NotNull(stream);
+            var buffer = new byte[stream.Length];
+            ReadExactly(stream, buffer);
+            var codecFactory = new Compression.CompressionCodecFactory();
+
+            long allocationsBeforeRead = 
MemoryAllocator.Default.Value.Statistics.Allocations;
+
+            using var reader = new ArrowStreamReader(buffer, codecFactory);
+            using RecordBatch batch = reader.ReadNextRecordBatch();
+            VerifyCompressedIpcFileBatch(batch);
+
+            Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations > 
allocationsBeforeRead);
+        }
+
         [Fact]
         public void ErrorReadingCompressedStreamWithoutCodecFactory()
         {
@@ -86,6 +109,21 @@ namespace Apache.Arrow.Compression.Tests
 
         }
 
+        private static void ReadExactly(Stream stream, byte[] buffer)
+        {
+            int offset = 0;
+            while (offset < buffer.Length)
+            {
+                int bytesRead = stream.Read(buffer, offset, buffer.Length - 
offset);
+                if (bytesRead == 0)
+                {
+                    throw new EndOfStreamException();
+                }
+
+                offset += bytesRead;
+            }
+        }
+
         private static void VerifyCompressedIpcFileBatch(RecordBatch batch)
         {
             var intArray = (Int32Array)batch.Column("integers");
@@ -103,4 +141,3 @@ namespace Apache.Arrow.Compression.Tests
         }
     }
 }
-
diff --git a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs 
b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index 5e7c57e..d04e0cd 100644
--- a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -16,9 +16,11 @@
 using System;
 using System.Buffers.Binary;
 using System.IO;
+using System.Reflection;
 using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Ipc;
+using Apache.Arrow.Memory;
 using Apache.Arrow.Types;
 using Xunit;
 
@@ -69,13 +71,20 @@ namespace Apache.Arrow.Tests
 
                 var memoryPool = new TestMemoryAllocator();
                 ArrowStreamReader reader = new ArrowStreamReader(stream, 
memoryPool, shouldLeaveOpen);
-                reader.ReadNextRecordBatch();
-
-                Assert.Equal(expectedAllocations, 
memoryPool.Statistics.Allocations);
-                Assert.True(memoryPool.Statistics.BytesAllocated > 0);
+                using (RecordBatch readBatch = reader.ReadNextRecordBatch())
+                {
+                    Assert.Equal(expectedAllocations, 
memoryPool.Statistics.Allocations);
+                    Assert.True(memoryPool.Statistics.BytesAllocated > 0);
+                    Assert.Equal(expectedAllocations, memoryPool.Rented);
+                }
 
                 reader.Dispose();
 
+                if (!createDictionaryArray)
+                {
+                    Assert.Equal(0, memoryPool.Rented);
+                }
+
                 if (shouldLeaveOpen)
                 {
                     Assert.True(stream.Position > 0);
@@ -109,6 +118,22 @@ namespace Apache.Arrow.Tests
             await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync, 
writeEnd);
         }
 
+        [Fact]
+        public async Task ReadRecordBatch_Memory_ExactLengthSlice()
+        {
+            await TestReaderFromMemoryExactLength((reader, originalBatch) =>
+            {
+                ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+                return Task.CompletedTask;
+            });
+        }
+
+        [Fact]
+        public async Task ReadRecordBatchAsync_Memory_ExactLengthSlice()
+        {
+            await 
TestReaderFromMemoryExactLength(ArrowReaderVerifier.VerifyReaderAsync);
+        }
+
         private static async Task TestReaderFromMemory(
             Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
             bool writeEnd)
@@ -131,6 +156,24 @@ namespace Apache.Arrow.Tests
             await verificationFunc(reader, originalBatch);
         }
 
+        private static async Task TestReaderFromMemoryExactLength(
+            Func<ArrowStreamReader, RecordBatch, Task> verificationFunc)
+        {
+            RecordBatch originalBatch = 
TestData.CreateSampleRecordBatch(length: 100);
+
+            ReadOnlyMemory<byte> buffer;
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, 
originalBatch.Schema);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+                buffer = stream.GetBuffer().AsMemory(0, 
checked((int)stream.Length));
+            }
+
+            ArrowStreamReader reader = new ArrowStreamReader(buffer);
+            await verificationFunc(reader, originalBatch);
+        }
+
         [Fact]
         public void ReadRecordBatch_EmptyStream()
         {
@@ -167,6 +210,40 @@ namespace Apache.Arrow.Tests
             }
         }
 
+        [Fact]
+        public async Task 
ReadRecordBatchAsync_PassesCancellationTokenToSchemaRead()
+        {
+            using var stream = new RequiresCancelableReadStream();
+            using var reader = new ArrowStreamReader(stream);
+            using var cancellation = new CancellationTokenSource();
+
+            await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
+                await reader.ReadNextRecordBatchAsync(cancellation.Token));
+
+            Assert.True(stream.SawCancelableToken);
+        }
+
+        [Fact]
+        public async Task 
ReadRecordBatchAsync_Stream_DictionaryFixtureWithoutRee()
+        {
+            using RecordBatch originalBatch = TestData.CreateSampleRecordBatch(
+                length: 100,
+                columnSetCount: 5,
+                excludedTypes: new 
System.Collections.Generic.HashSet<ArrowTypeId> { ArrowTypeId.RunEndEncoded });
+
+            using var stream = new MemoryStream();
+            using (ArrowStreamWriter writer = new ArrowStreamWriter(stream, 
originalBatch.Schema, leaveOpen: true))
+            {
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+            }
+
+            stream.Position = 0;
+
+            using var reader = new ArrowStreamReader(stream);
+            await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch);
+        }
+
         [Theory]
         [InlineData(true, true)]
         [InlineData(true, false)]
@@ -177,6 +254,84 @@ namespace Apache.Arrow.Tests
             await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync, writeEnd, createDictionaryArray);
         }
 
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream(bool createDictionaryArray)
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: createDictionaryArray);
+
+            byte[] buffer;
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+                buffer = stream.ToArray();
+            }
+
+            using (MemoryStream stream = new MemoryStream(buffer))
+            {
+                ArrowStreamReader reader = new ArrowStreamReader(stream);
+                await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch);
+            }
+        }
+
+        [Fact]
+        public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesExplicitAllocator()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+
+            byte[] buffer;
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+                buffer = stream.ToArray();
+            }
+
+            var allocator = new TestMemoryAllocator();
+            using (MemoryStream stream = new MemoryStream(buffer))
+            using (var reader = new ArrowStreamReader(stream, allocator))
+            {
+                using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync())
+                {
+                    ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+                }
+
+                Assert.True(allocator.Statistics.Allocations > 0);
+                Assert.Equal(0, allocator.Rented);
+                Assert.Null(await reader.ReadNextRecordBatchAsync());
+            }
+        }
+
+        [Fact]
+        public async Task ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesDefaultAllocator()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+
+            byte[] buffer;
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+                buffer = stream.ToArray();
+            }
+
+            long allocationsBeforeRead = MemoryAllocator.Default.Value.Statistics.Allocations;
+
+            using (MemoryStream stream = new MemoryStream(buffer))
+            using (var reader = new ArrowStreamReader(stream, MemoryAllocator.Default.Value))
+            using (RecordBatch readBatch = await reader.ReadNextRecordBatchAsync())
+            {
+                ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+            }
+
+            Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations > allocationsBeforeRead);
+        }
+
         private static async Task TestReaderFromStream(
             Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
             bool writeEnd, bool createDictionaryArray)
@@ -199,6 +354,92 @@ namespace Apache.Arrow.Tests
             }
         }
 
+        [Fact]
+        public async Task ReadRecordBatch_ExposedMemoryStream_BatchRemainsUsableAfterDispose()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+            RecordBatch readBatch;
+
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+
+                stream.Position = 0;
+
+                using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true))
+                {
+                    readBatch = reader.ReadNextRecordBatch();
+                }
+            }
+
+            using (readBatch)
+            {
+                ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+            }
+        }
+
+        [Fact]
+        public async Task ReadRecordBatch_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+            RecordBatch readBatch;
+            byte[] streamBuffer;
+
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+
+                streamBuffer = stream.GetBuffer();
+                stream.Position = 0;
+
+                using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true))
+                {
+                    readBatch = reader.ReadNextRecordBatch();
+                }
+            }
+
+            System.Array.Clear(streamBuffer, 0, streamBuffer.Length);
+
+            using (readBatch)
+            {
+                ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+            }
+        }
+
+        [Fact]
+        public async Task ReadRecordBatchAsync_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+            RecordBatch readBatch;
+            byte[] streamBuffer;
+
+            using (MemoryStream stream = new MemoryStream())
+            {
+                ArrowStreamWriter writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true);
+                await writer.WriteRecordBatchAsync(originalBatch);
+                await writer.WriteEndAsync();
+
+                streamBuffer = stream.GetBuffer();
+                stream.Position = 0;
+
+                using (ArrowStreamReader reader = new ArrowStreamReader(stream, leaveOpen: true))
+                {
+                    readBatch = await reader.ReadNextRecordBatchAsync();
+                }
+            }
+
+            System.Array.Clear(streamBuffer, 0, streamBuffer.Length);
+
+            using (readBatch)
+            {
+                ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+            }
+        }
+
         [Theory]
         [InlineData(true)]
         [InlineData(false)]
@@ -243,11 +484,34 @@ namespace Apache.Arrow.Tests
         /// <summary>
         /// A stream class that only returns a part of the data at a time.
         /// </summary>
-        private class PartialReadStream : MemoryStream
+        private class PartialReadStream : Stream
         {
+            private readonly MemoryStream _innerStream = new MemoryStream();
+
             // by default return 20 bytes at a time
             public int PartialReadLength { get; set; } = 20;
 
+            public override bool CanRead => _innerStream.CanRead;
+            public override bool CanSeek => _innerStream.CanSeek;
+            public override bool CanWrite => _innerStream.CanWrite;
+            public override long Length => _innerStream.Length;
+            public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }
+
+            public override void Flush() => _innerStream.Flush();
+            public override long Seek(long offset, SeekOrigin origin) => _innerStream.Seek(offset, origin);
+            public override void SetLength(long value) => _innerStream.SetLength(value);
+            public override void Write(byte[] buffer, int offset, int count) => _innerStream.Write(buffer, offset, count);
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                return _innerStream.Read(buffer, offset, Math.Min(count, PartialReadLength));
+            }
+
+            public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default)
+            {
+                return _innerStream.ReadAsync(buffer, offset, Math.Min(count, PartialReadLength), cancellationToken);
+            }
+
 #if NET5_0_OR_GREATER
             public override int Read(Span<byte> destination)
             {
@@ -256,7 +520,7 @@ namespace Apache.Arrow.Tests
                     destination = destination.Slice(0, PartialReadLength);
                 }
 
-                return base.Read(destination);
+                return _innerStream.Read(destination);
             }
 
             public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
@@ -266,53 +530,50 @@ namespace Apache.Arrow.Tests
                     destination = destination.Slice(0, PartialReadLength);
                 }
 
-                return base.ReadAsync(destination, cancellationToken);
-            }
-#else
-            public override int Read(byte[] buffer, int offset, int length)
-            {
-                return base.Read(buffer, offset, Math.Min(length, PartialReadLength));
-            }
-
-            public override Task<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken = default)
-            {
-                return base.ReadAsync(buffer, offset, Math.Min(length, PartialReadLength), cancellationToken);
+                return _innerStream.ReadAsync(destination, cancellationToken);
             }
 #endif
         }
 
-        [Fact]
-        public unsafe void MalformedColumnNameLength()
+        private class RequiresCancelableReadStream : Stream
         {
-            const int FieldNameLengthOffset = 108;
-            const int FakeFieldNameLength = 165535;
+            public bool SawCancelableToken { get; private set; }
 
-            byte[] buffer;
-            using (var stream = new MemoryStream())
+            public override bool CanRead => true;
+            public override bool CanSeek => false;
+            public override bool CanWrite => false;
+            public override long Length => throw new NotSupportedException();
+            public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
+
+            public override void Flush() { }
+            public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+            public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
+            public override void SetLength(long value) => throw new NotSupportedException();
+            public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
+
+#if NET5_0_OR_GREATER
+            public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
             {
-                Schema schema = new(
-                    [new Field("index", Int32Type.Default, nullable: false)],
-                    metadata: []);
-                using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true))
+                SawCancelableToken = cancellationToken.CanBeCanceled;
+                if (!SawCancelableToken)
                 {
-                    writer.WriteStart();
-                    writer.WriteEnd();
+                    throw new InvalidOperationException("Expected the caller's cancellation token during schema read.");
                 }
-                buffer = stream.ToArray();
-            }
 
-            Span<int> length = buffer.AsSpan().Slice(FieldNameLengthOffset, sizeof(int)).CastTo<int>();
-            Assert.Equal(5, length[0]);
-            length[0] = FakeFieldNameLength;
-
-            Assert.Throws<ArgumentOutOfRangeException>(() =>
+                throw new OperationCanceledException(cancellationToken);
+            }
+#else
+            public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
             {
-                using (var stream = new MemoryStream(buffer))
-                using (var reader = new ArrowStreamReader(stream))
+                SawCancelableToken = cancellationToken.CanBeCanceled;
+                if (!SawCancelableToken)
                 {
-                    reader.ReadNextRecordBatch();
+                    throw new InvalidOperationException("Expected the caller's cancellation token during schema read.");
                 }
-            });
+
+                throw new OperationCanceledException(cancellationToken);
+            }
+#endif
         }
 
         [Fact]
@@ -468,34 +729,52 @@ namespace Apache.Arrow.Tests
 
         private static void WriteInt64LittleEndian(byte[] buffer, int offset, long value)
         {
-            System.Buffers.Binary.BinaryPrimitives.WriteInt64LittleEndian(
-                buffer.AsSpan(offset), value);
+            BinaryPrimitives.WriteInt64LittleEndian(buffer.AsSpan(offset), value);
         }
 
         [Fact]
-        public async Task EmptyStreamNoSyncRead()
+        public unsafe void MalformedColumnNameLength()
         {
-            using (var stream = new EmptyAsyncOnlyStream())
+            const int FieldNameLengthOffset = 108;
+            const int FakeFieldNameLength = 165535;
+
+            byte[] buffer;
+            using (var stream = new MemoryStream())
             {
-                var reader = new ArrowStreamReader(stream);
-                var schema = await reader.GetSchema();
-                Assert.Null(schema);
+                Schema schema = new(
+                    [new Field("index", Int32Type.Default, nullable: false)],
+                    metadata: []);
+                using (var writer = new ArrowStreamWriter(stream, schema, leaveOpen: true))
+                {
+                    writer.WriteStart();
+                    writer.WriteEnd();
+                }
+                buffer = stream.ToArray();
             }
-        }
 
-        private static short ToInt16LittleEndian(byte[] buffer, int offset)
-        {
-            return BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
-        }
+            Span<int> length = buffer.AsSpan().Slice(FieldNameLengthOffset, sizeof(int)).CastTo<int>();
+            Assert.Equal(5, length[0]);
+            length[0] = FakeFieldNameLength;
 
-        private static int ToInt32LittleEndian(byte[] buffer, int offset)
-        {
-            return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+            Assert.Throws<ArgumentOutOfRangeException>(() =>
+            {
+                using (var stream = new MemoryStream(buffer))
+                using (var reader = new ArrowStreamReader(stream))
+                {
+                    reader.ReadNextRecordBatch();
+                }
+            });
         }
 
-        private static long ToInt64LittleEndian(byte[] buffer, int offset)
+        [Fact]
+        public async Task EmptyStreamNoSyncRead()
         {
-            return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+            using (var stream = new EmptyAsyncOnlyStream())
+            {
+                var reader = new ArrowStreamReader(stream);
+                var schema = await reader.GetSchema();
+                Assert.Null(schema);
+            }
         }
 
         private class EmptyAsyncOnlyStream : Stream
@@ -512,5 +791,20 @@ namespace Apache.Arrow.Tests
             public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
             public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(0);
         }
+
+        private static short ToInt16LittleEndian(byte[] buffer, int offset)
+        {
+            return BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
+        }
+
+        private static int ToInt32LittleEndian(byte[] buffer, int offset)
+        {
+            return BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+        }
+
+        private static long ToInt64LittleEndian(byte[] buffer, int offset)
+        {
+            return BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+        }
     }
 }


Reply via email to