This is an automated email from the ASF dual-hosted git repository.
CurtHagenlocher pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-dotnet.git
The following commit(s) were added to refs/heads/main by this push:
new 3434344 perf: speed up MemoryStream IPC stream reads (#340)
3434344 is described below
commit 343434468c749c47ed0f3d20823bb29717a3c33c
Author: InCerryGit <[email protected]>
AuthorDate: Wed Apr 29 00:44:58 2026 +0800
perf: speed up MemoryStream IPC stream reads (#340)
## Summary
This improves `ArrowStreamReader` read performance for `MemoryStream`
instances that expose their underlying buffer. The reader now reads IPC
message and schema metadata directly from the exposed buffer, while record
batch bodies are still copied into reader-owned buffers as before.
The change is intentionally scoped to MemoryStream-backed IPC stream
reads:
- `MemoryStream` instances whose backing buffer is publicly visible
(`TryGetBuffer` succeeds) can use the fast path
- `MemoryStream` instances with a non-public buffer, and partial-read
streams, continue through the fallback stream-read path
- record batch body data is still copied into allocator-owned memory
before array construction
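- `ArrowMemoryReader` continuation-token handling is corrected for
exact-length in-memory buffers: the bounds check now uses `<` instead of
`<=`, so a buffer that ends exactly at `_bufferPosition + sizeof(int)` is
no longer reported as corrupted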
## Benchmark
BenchmarkDotNet ShortRun, `ArrowReaderBenchmark`:
| Scenario | Before | After |
|---|---:|---:|
| `ArrowReaderWithMemoryStream_ManagedMemory`, 100000 rows / 1 column | 21629.3 us | 7707.6 us |
| `ArrowReaderWithMemoryStream_ManagedMemory`, 100000 rows / 5 columns | 91112.3 us | 40137.5 us |
## Validation
- `dotnet test test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj -c Release --filter "FullyQualifiedName~Apache.Arrow.Tests.ArrowStreamReaderTests"`
- `dotnet test test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj -c Release --filter "FullyQualifiedName~Apache.Arrow.Compression.Tests.ArrowStreamReaderTests"`
- `dotnet build Apache.Arrow.sln -c Release`
---
.../Ipc/ArrowMemoryReaderImplementation.cs | 4 +-
.../Ipc/ArrowMemoryStreamReaderImplementation.cs | 214 +++++++++++
src/Apache.Arrow/Ipc/ArrowStreamReader.cs | 19 +-
.../Ipc/ArrowStreamReaderImplementation.cs | 27 +-
.../ArrowReaderBenchmark.cs | 92 ++++-
.../ArrowStreamReaderTests.cs | 41 ++-
test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs | 408 ++++++++++++++++++---
7 files changed, 730 insertions(+), 75 deletions(-)
diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index e9cd936..887036c 100644
--- a/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -87,7 +87,7 @@ namespace Apache.Arrow.Ipc
else if (messageLength ==
MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation
message, read the next 4 for the length
- if (_buffer.Length <= _bufferPosition + sizeof(int))
+ if (_buffer.Length < _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message.
Received a continuation token at the end of the message.");
}
@@ -136,7 +136,7 @@ namespace Apache.Arrow.Ipc
if (schemaMessageLength == MessageSerializer.IpcContinuationToken)
{
// ARROW-6313, if the first 4 bytes are continuation message,
read the next 4 for the length
- if (_buffer.Length <= _bufferPosition + sizeof(int))
+ if (_buffer.Length < _bufferPosition + sizeof(int))
{
throw new InvalidDataException("Corrupted IPC message.
Received a continuation token at the end of the message.");
}
diff --git a/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
new file mode 100644
index 0000000..01d3b0b
--- /dev/null
+++ b/src/Apache.Arrow/Ipc/ArrowMemoryStreamReaderImplementation.cs
@@ -0,0 +1,214 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Buffers;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Memory;
+
+namespace Apache.Arrow.Ipc
+{
+ /// <summary>
+ /// Reads Arrow IPC streams from a <see cref="MemoryStream"/> whose
backing buffer is publicly visible.
+ /// </summary>
+ /// <remarks>
+ /// Message metadata can be read directly from the exposed stream buffer,
but record batch bodies are
+ /// still copied into allocator-owned buffers to preserve <see
cref="ArrowStreamReader"/> ownership semantics.
+ /// </remarks>
+ internal sealed class ArrowMemoryStreamReaderImplementation :
ArrowStreamReaderImplementation
+ {
+ private readonly MemoryStream _stream;
+ private readonly Memory<byte> _streamMemory;
+
+ public ArrowMemoryStreamReaderImplementation(
+ MemoryStream stream,
+ MemoryAllocator allocator,
+ ICompressionCodecFactory compressionCodecFactory,
+ bool leaveOpen,
+ ExtensionTypeRegistry extensionRegistry)
+ : base(stream, allocator, compressionCodecFactory, leaveOpen,
extensionRegistry)
+ {
+ _stream = stream;
+
+ if (!stream.TryGetBuffer(out ArraySegment<byte> streamBuffer))
+ {
+ throw new InvalidOperationException("Expected MemoryStream to
expose its backing buffer.");
+ }
+
+ _streamMemory = streamBuffer.Array.AsMemory(streamBuffer.Offset,
streamBuffer.Count);
+ }
+
+ public override ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ try
+ {
+ return new ValueTask<RecordBatch>(ReadNextRecordBatch());
+ }
+ catch (Exception ex)
+ {
+ return new
ValueTask<RecordBatch>(Task.FromException<RecordBatch>(ex));
+ }
+ }
+
+ public override RecordBatch ReadNextRecordBatch()
+ {
+ ReadSchema();
+
+ ReadResult result = default;
+ do
+ {
+ result = ReadMessageFromMemory();
+ } while (result.Batch == null && result.MessageLength > 0);
+
+ return result.Batch;
+ }
+
+ public override ValueTask<Schema> ReadSchemaAsync(CancellationToken
cancellationToken = default)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ if (HasReadSchema)
+ {
+ return new ValueTask<Schema>(_schema);
+ }
+
+ try
+ {
+ ReadSchema();
+ return new ValueTask<Schema>(_schema);
+ }
+ catch (Exception ex)
+ {
+ return new ValueTask<Schema>(Task.FromException<Schema>(ex));
+ }
+ }
+
+ public override void ReadSchema()
+ {
+ if (HasReadSchema)
+ {
+ return;
+ }
+
+ int schemaMessageLength =
ReadMessageLengthFromMemory(throwOnFullRead: true, returnOnEmptyStream: true);
+ if (schemaMessageLength == 0)
+ {
+ return;
+ }
+
+ Memory<byte> schemaBuffer = ReadMemory(schemaMessageLength);
+ _schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(CreateByteBuffer(schemaBuffer)),
ref _dictionaryMemo, _extensionRegistry);
+ }
+
+ private ReadResult ReadMessageFromMemory()
+ {
+ int messageLength = ReadMessageLengthFromMemory(throwOnFullRead:
false, returnOnEmptyStream: false);
+ if (messageLength == 0)
+ {
+ return default;
+ }
+
+ Memory<byte> messageBuffer = ReadMemory(messageLength);
+ Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuffer));
+
+ if (message.BodyLength > int.MaxValue)
+ {
+ throw new OverflowException(
+ $"Arrow IPC message body length ({message.BodyLength}) is
larger than " +
+ $"the maximum supported message size ({int.MaxValue})");
+ }
+
+ int bodyLength = (int)message.BodyLength;
+ Memory<byte> sourceBodyBuffer = ReadMemory(bodyLength);
+ IMemoryOwner<byte> bodyBufferOwner =
AllocateMessageBodyBuffer(bodyLength);
+ Memory<byte> bodyBuffer = bodyBufferOwner.Memory.Slice(0,
bodyLength);
+ sourceBodyBuffer.CopyTo(bodyBuffer);
+ Google.FlatBuffers.ByteBuffer bodybb =
CreateByteBuffer(bodyBuffer);
+
+ // Keep stream-reader ownership semantics: batches outlive the
source MemoryStream buffer.
+ return new ReadResult(messageLength,
CreateArrowObjectFromMessage(message, bodybb, bodyBufferOwner));
+ }
+
+ private int ReadMessageLengthFromMemory(bool throwOnFullRead, bool
returnOnEmptyStream)
+ {
+ if (_stream.Position == _stream.Length && returnOnEmptyStream)
+ {
+ return 0;
+ }
+
+ if (!TryReadInt32(throwOnFullRead, out int messageLength))
+ {
+ return 0;
+ }
+
+ if (messageLength == MessageSerializer.IpcContinuationToken &&
+ !TryReadInt32(throwOnFullRead, out messageLength))
+ {
+ return 0;
+ }
+
+ return messageLength;
+ }
+
+ private bool TryReadInt32(bool throwOnFullRead, out int value)
+ {
+ value = 0;
+
+ if (!TryReadMemory(sizeof(int), throwOnFullRead, out Memory<byte>
buffer))
+ {
+ return false;
+ }
+
+ value = BitUtility.ReadInt32(buffer);
+ return true;
+ }
+
+ private bool TryReadMemory(int length, bool throwOnFullRead, out
Memory<byte> buffer)
+ {
+ buffer = default;
+
+ long remainingLength = _stream.Length - _stream.Position;
+ if (remainingLength < length)
+ {
+ if (throwOnFullRead)
+ {
+ throw new InvalidOperationException("Unexpectedly reached
the end of the stream before a full buffer was read.");
+ }
+
+ _stream.Position = _stream.Length;
+ return false;
+ }
+
+ buffer = ReadMemory(length);
+ return true;
+ }
+
+ private Memory<byte> ReadMemory(int length)
+ {
+ if (length == 0)
+ {
+ return Memory<byte>.Empty;
+ }
+
+ Memory<byte> buffer =
_streamMemory.Slice(checked((int)_stream.Position), length);
+ _stream.Position += length;
+ return buffer;
+ }
+ }
+}
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
index e5dade2..6100eea 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
@@ -68,7 +68,7 @@ namespace Apache.Arrow.Ipc
if (stream == null)
throw new ArgumentNullException(nameof(stream));
- _implementation = new ArrowStreamReaderImplementation(stream,
allocator, compressionCodecFactory, leaveOpen);
+ _implementation = CreateImplementation(stream, allocator,
compressionCodecFactory, leaveOpen, extensionRegistry: null);
}
public ArrowStreamReader(ArrowContext context, Stream stream, bool
leaveOpen = false)
@@ -78,7 +78,7 @@ namespace Apache.Arrow.Ipc
if (context == null)
throw new ArgumentNullException(nameof(context));
- _implementation = new ArrowStreamReaderImplementation(stream,
context.Allocator, context.CompressionCodecFactory, leaveOpen,
context.ExtensionRegistry);
+ _implementation = CreateImplementation(stream, context.Allocator,
context.CompressionCodecFactory, leaveOpen, context.ExtensionRegistry);
}
public ArrowStreamReader(ReadOnlyMemory<byte> buffer)
@@ -104,6 +104,21 @@ namespace Apache.Arrow.Ipc
_implementation = implementation;
}
+ private static ArrowReaderImplementation CreateImplementation(
+ Stream stream,
+ MemoryAllocator allocator,
+ ICompressionCodecFactory compressionCodecFactory,
+ bool leaveOpen,
+ ExtensionTypeRegistry extensionRegistry)
+ {
+ if (stream is MemoryStream memoryStream &&
memoryStream.TryGetBuffer(out _))
+ {
+ return new ArrowMemoryStreamReaderImplementation(memoryStream,
allocator, compressionCodecFactory, leaveOpen, extensionRegistry);
+ }
+
+ return new ArrowStreamReaderImplementation(stream, allocator,
compressionCodecFactory, leaveOpen, extensionRegistry);
+ }
+
public void Dispose()
{
Dispose(true);
diff --git a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 9d0cbe3..179b998 100644
--- a/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -47,11 +47,10 @@ namespace Apache.Arrow.Ipc
}
}
- public override async ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ public override ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
- // TODO: Loop until a record batch is read.
cancellationToken.ThrowIfCancellationRequested();
- return await
ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
+ return ReadRecordBatchAsync(cancellationToken);
}
public override RecordBatch ReadNextRecordBatch()
@@ -61,7 +60,7 @@ namespace Apache.Arrow.Ipc
protected async ValueTask<RecordBatch>
ReadRecordBatchAsync(CancellationToken cancellationToken = default)
{
- await ReadSchemaAsync().ConfigureAwait(false);
+ await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
ReadResult result = default;
do
@@ -94,7 +93,7 @@ namespace Apache.Arrow.Ipc
int bodyLength = checked((int)message.BodyLength);
- IMemoryOwner<byte> bodyBuffOwner =
_allocator.Allocate(bodyLength);
+ IMemoryOwner<byte> bodyBuffOwner =
AllocateMessageBodyBuffer(bodyLength);
Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0,
bodyLength);
bytesRead = await BaseStream.ReadFullBufferAsync(bodyBuff,
cancellationToken)
.ConfigureAwait(false);
@@ -145,7 +144,7 @@ namespace Apache.Arrow.Ipc
}
int bodyLength = (int)message.BodyLength;
- IMemoryOwner<byte> bodyBuffOwner =
_allocator.Allocate(bodyLength);
+ IMemoryOwner<byte> bodyBuffOwner =
AllocateMessageBodyBuffer(bodyLength);
Memory<byte> bodyBuff = bodyBuffOwner.Memory.Slice(0,
bodyLength);
bytesRead = BaseStream.ReadFullBuffer(bodyBuff);
EnsureFullRead(bodyBuff, bytesRead);
@@ -157,13 +156,25 @@ namespace Apache.Arrow.Ipc
return new ReadResult(messageLength, result);
}
- public override async ValueTask<Schema>
ReadSchemaAsync(CancellationToken cancellationToken = default)
+ protected IMemoryOwner<byte> AllocateMessageBodyBuffer(int bodyLength)
{
+ return _allocator.Allocate(bodyLength);
+ }
+
+ public override ValueTask<Schema> ReadSchemaAsync(CancellationToken
cancellationToken = default)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
if (HasReadSchema)
{
- return _schema;
+ return new ValueTask<Schema>(_schema);
}
+ return ReadSchemaAsyncCore(cancellationToken);
+ }
+
+ private async ValueTask<Schema> ReadSchemaAsyncCore(CancellationToken
cancellationToken)
+ {
// Figure out length of schema
int schemaMessageLength = await
ReadMessageLengthAsync(throwOnFullRead: true, returnOnEmptyStream: true,
cancellationToken)
.ConfigureAwait(false);
diff --git a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
index 3760d94..9305adb 100644
--- a/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
+++ b/test/Apache.Arrow.Benchmarks/ArrowReaderBenchmark.cs
@@ -14,8 +14,8 @@
// limitations under the License.
using System;
+using System.Collections.Generic;
using System.IO;
-using System.Linq;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Memory;
@@ -32,13 +32,19 @@ namespace Apache.Arrow.Benchmarks
[Params(10_000, 1_000_000)]
public int Count { get; set; }
+ [Params(1, 5)]
+ public int ColumnSetCount { get; set; }
+
private MemoryStream _memoryStream;
private static readonly MemoryAllocator s_allocator = new
TestMemoryAllocator();
[GlobalSetup]
public async Task GlobalSetup()
{
- RecordBatch batch = TestData.CreateSampleRecordBatch(length:
Count, createDictionaryArray: false);
+ RecordBatch batch = TestData.CreateSampleRecordBatch(
+ length: Count,
+ columnSetCount: ColumnSetCount,
+ excludedTypes: new HashSet<ArrowTypeId> {
ArrowTypeId.Dictionary, ArrowTypeId.RunEndEncoded });
_memoryStream = new MemoryStream();
ArrowStreamWriter writer = new ArrowStreamWriter(_memoryStream,
batch.Schema);
@@ -83,6 +89,73 @@ namespace Apache.Arrow.Benchmarks
return sum;
}
+ [Benchmark]
+ public async Task<double>
ArrowReaderWithMemoryStream_ExplicitDefaultAllocator()
+ {
+ double sum = 0;
+ var reader = new ArrowStreamReader(_memoryStream,
MemoryAllocator.Default.Value);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) !=
null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ public async Task<double>
ArrowReaderWithNonPubliclyVisibleMemoryStream()
+ {
+ double sum = 0;
+ using var stream = CreateNonPubliclyVisibleReadStream();
+ using var reader = new ArrowStreamReader(stream);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) !=
null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ public async Task<double>
ArrowReaderWithNonPubliclyVisibleMemoryStream_ManagedMemory()
+ {
+ double sum = 0;
+ using var stream = CreateNonPubliclyVisibleReadStream();
+ using var reader = new ArrowStreamReader(stream, s_allocator);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) !=
null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
+ [Benchmark]
+ public async Task<double>
ArrowReaderWithNonPubliclyVisibleMemoryStream_ExplicitDefaultAllocator()
+ {
+ double sum = 0;
+ using var stream = CreateNonPubliclyVisibleReadStream();
+ using var reader = new ArrowStreamReader(stream,
MemoryAllocator.Default.Value);
+ RecordBatch recordBatch;
+ while ((recordBatch = await reader.ReadNextRecordBatchAsync()) !=
null)
+ {
+ using (recordBatch)
+ {
+ sum += SumAllNumbers(recordBatch);
+ }
+ }
+ return sum;
+ }
+
[Benchmark]
public async Task<double> ArrowReaderWithMemory()
{
@@ -99,14 +172,25 @@ namespace Apache.Arrow.Benchmarks
return sum;
}
+ private MemoryStream CreateNonPubliclyVisibleReadStream()
+ {
+ return new MemoryStream(
+ _memoryStream.GetBuffer(),
+ index: 0,
+ count: checked((int)_memoryStream.Length),
+ writable: false,
+ publiclyVisible: false);
+ }
+
private static double SumAllNumbers(RecordBatch recordBatch)
{
double sum = 0;
for (int k = 0; k < recordBatch.ColumnCount; k++)
{
- var array = recordBatch.Arrays.ElementAt(k);
- switch (recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId)
+ var array = recordBatch.Column(k);
+ ArrowTypeId typeId =
recordBatch.Schema.GetFieldByIndex(k).DataType.TypeId;
+ switch (typeId)
{
case ArrowTypeId.Int64:
Int64Array int64Array = (Int64Array)array;
diff --git a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
index 9c2bf75..99e7442 100644
--- a/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Compression.Tests/ArrowStreamReaderTests.cs
@@ -14,8 +14,10 @@
// limitations under the License.
using System;
+using System.IO;
using System.Reflection;
using Apache.Arrow.Ipc;
+using Apache.Arrow.Memory;
using Apache.Arrow.Tests;
using Xunit;
@@ -47,13 +49,34 @@ namespace Apache.Arrow.Compression.Tests
using var stream =
assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}");
Assert.NotNull(stream);
var buffer = new byte[stream.Length];
- stream.ReadFullBuffer(buffer);
+ ReadExactly(stream, buffer);
var codecFactory = new Compression.CompressionCodecFactory();
using var reader = new ArrowStreamReader(buffer, codecFactory);
VerifyCompressedIpcFileBatch(reader.ReadNextRecordBatch());
}
+ [Theory]
+ [InlineData("ipc_lz4_compression.arrow_stream")]
+ [InlineData("ipc_zstd_compression.arrow_stream")]
+ public void
CanReadCompressedIpcStreamFromMemoryBuffer_UsesDefaultAllocator(string fileName)
+ {
+ var assembly = Assembly.GetExecutingAssembly();
+ using var stream =
assembly.GetManifestResourceStream($"Apache.Arrow.Compression.Tests.Resources.{fileName}");
+ Assert.NotNull(stream);
+ var buffer = new byte[stream.Length];
+ ReadExactly(stream, buffer);
+ var codecFactory = new Compression.CompressionCodecFactory();
+
+ long allocationsBeforeRead =
MemoryAllocator.Default.Value.Statistics.Allocations;
+
+ using var reader = new ArrowStreamReader(buffer, codecFactory);
+ using RecordBatch batch = reader.ReadNextRecordBatch();
+ VerifyCompressedIpcFileBatch(batch);
+
+ Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations >
allocationsBeforeRead);
+ }
+
[Fact]
public void ErrorReadingCompressedStreamWithoutCodecFactory()
{
@@ -86,6 +109,21 @@ namespace Apache.Arrow.Compression.Tests
}
+ private static void ReadExactly(Stream stream, byte[] buffer)
+ {
+ int offset = 0;
+ while (offset < buffer.Length)
+ {
+ int bytesRead = stream.Read(buffer, offset, buffer.Length -
offset);
+ if (bytesRead == 0)
+ {
+ throw new EndOfStreamException();
+ }
+
+ offset += bytesRead;
+ }
+ }
+
private static void VerifyCompressedIpcFileBatch(RecordBatch batch)
{
var intArray = (Int32Array)batch.Column("integers");
@@ -103,4 +141,3 @@ namespace Apache.Arrow.Compression.Tests
}
}
}
-
diff --git a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index 5e7c57e..d04e0cd 100644
--- a/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -16,9 +16,11 @@
using System;
using System.Buffers.Binary;
using System.IO;
+using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
+using Apache.Arrow.Memory;
using Apache.Arrow.Types;
using Xunit;
@@ -69,13 +71,20 @@ namespace Apache.Arrow.Tests
var memoryPool = new TestMemoryAllocator();
ArrowStreamReader reader = new ArrowStreamReader(stream,
memoryPool, shouldLeaveOpen);
- reader.ReadNextRecordBatch();
-
- Assert.Equal(expectedAllocations,
memoryPool.Statistics.Allocations);
- Assert.True(memoryPool.Statistics.BytesAllocated > 0);
+ using (RecordBatch readBatch = reader.ReadNextRecordBatch())
+ {
+ Assert.Equal(expectedAllocations,
memoryPool.Statistics.Allocations);
+ Assert.True(memoryPool.Statistics.BytesAllocated > 0);
+ Assert.Equal(expectedAllocations, memoryPool.Rented);
+ }
reader.Dispose();
+ if (!createDictionaryArray)
+ {
+ Assert.Equal(0, memoryPool.Rented);
+ }
+
if (shouldLeaveOpen)
{
Assert.True(stream.Position > 0);
@@ -109,6 +118,22 @@ namespace Apache.Arrow.Tests
await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync,
writeEnd);
}
+ [Fact]
+ public async Task ReadRecordBatch_Memory_ExactLengthSlice()
+ {
+ await TestReaderFromMemoryExactLength((reader, originalBatch) =>
+ {
+ ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+ return Task.CompletedTask;
+ });
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_Memory_ExactLengthSlice()
+ {
+ await
TestReaderFromMemoryExactLength(ArrowReaderVerifier.VerifyReaderAsync);
+ }
+
private static async Task TestReaderFromMemory(
Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
bool writeEnd)
@@ -131,6 +156,24 @@ namespace Apache.Arrow.Tests
await verificationFunc(reader, originalBatch);
}
+ private static async Task TestReaderFromMemoryExactLength(
+ Func<ArrowStreamReader, RecordBatch, Task> verificationFunc)
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100);
+
+ ReadOnlyMemory<byte> buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.GetBuffer().AsMemory(0,
checked((int)stream.Length));
+ }
+
+ ArrowStreamReader reader = new ArrowStreamReader(buffer);
+ await verificationFunc(reader, originalBatch);
+ }
+
[Fact]
public void ReadRecordBatch_EmptyStream()
{
@@ -167,6 +210,40 @@ namespace Apache.Arrow.Tests
}
}
+ [Fact]
+ public async Task
ReadRecordBatchAsync_PassesCancellationTokenToSchemaRead()
+ {
+ using var stream = new RequiresCancelableReadStream();
+ using var reader = new ArrowStreamReader(stream);
+ using var cancellation = new CancellationTokenSource();
+
+ await Assert.ThrowsAnyAsync<OperationCanceledException>(async () =>
+ await reader.ReadNextRecordBatchAsync(cancellation.Token));
+
+ Assert.True(stream.SawCancelableToken);
+ }
+
+ [Fact]
+ public async Task
ReadRecordBatchAsync_Stream_DictionaryFixtureWithoutRee()
+ {
+ using RecordBatch originalBatch = TestData.CreateSampleRecordBatch(
+ length: 100,
+ columnSetCount: 5,
+ excludedTypes: new
System.Collections.Generic.HashSet<ArrowTypeId> { ArrowTypeId.RunEndEncoded });
+
+ using var stream = new MemoryStream();
+ using (ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true))
+ {
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ }
+
+ stream.Position = 0;
+
+ using var reader = new ArrowStreamReader(stream);
+ await ArrowReaderVerifier.VerifyReaderAsync(reader, originalBatch);
+ }
+
[Theory]
[InlineData(true, true)]
[InlineData(true, false)]
@@ -177,6 +254,84 @@ namespace Apache.Arrow.Tests
await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync,
writeEnd, createDictionaryArray);
}
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public async Task
ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream(bool createDictionaryArray)
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray:
createDictionaryArray);
+
+ byte[] buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.ToArray();
+ }
+
+ using (MemoryStream stream = new MemoryStream(buffer))
+ {
+ ArrowStreamReader reader = new ArrowStreamReader(stream);
+ await ArrowReaderVerifier.VerifyReaderAsync(reader,
originalBatch);
+ }
+ }
+
+ [Fact]
+ public async Task
ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesExplicitAllocator()
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+
+ byte[] buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.ToArray();
+ }
+
+ var allocator = new TestMemoryAllocator();
+ using (MemoryStream stream = new MemoryStream(buffer))
+ using (var reader = new ArrowStreamReader(stream, allocator))
+ {
+ using (RecordBatch readBatch = await
reader.ReadNextRecordBatchAsync())
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch,
readBatch);
+ }
+
+ Assert.True(allocator.Statistics.Allocations > 0);
+ Assert.Equal(0, allocator.Rented);
+ Assert.Null(await reader.ReadNextRecordBatchAsync());
+ }
+ }
+
+ [Fact]
+ public async Task
ReadRecordBatchAsync_NonPubliclyVisibleMemoryStream_UsesDefaultAllocator()
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+
+ byte[] buffer;
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+ buffer = stream.ToArray();
+ }
+
+ long allocationsBeforeRead =
MemoryAllocator.Default.Value.Statistics.Allocations;
+
+ using (MemoryStream stream = new MemoryStream(buffer))
+ using (var reader = new ArrowStreamReader(stream,
MemoryAllocator.Default.Value))
+ using (RecordBatch readBatch = await
reader.ReadNextRecordBatchAsync())
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+
+ Assert.True(MemoryAllocator.Default.Value.Statistics.Allocations >
allocationsBeforeRead);
+ }
+
private static async Task TestReaderFromStream(
Func<ArrowStreamReader, RecordBatch, Task> verificationFunc,
bool writeEnd, bool createDictionaryArray)
@@ -199,6 +354,92 @@ namespace Apache.Arrow.Tests
}
}
+ [Fact]
+ public async Task
ReadRecordBatch_ExposedMemoryStream_BatchRemainsUsableAfterDispose()
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch readBatch;
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+
+ stream.Position = 0;
+
+ using (ArrowStreamReader reader = new
ArrowStreamReader(stream, leaveOpen: true))
+ {
+ readBatch = reader.ReadNextRecordBatch();
+ }
+ }
+
+ using (readBatch)
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+ }
+
+ [Fact]
+ public async Task
ReadRecordBatch_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer()
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch readBatch;
+ byte[] streamBuffer;
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+
+ streamBuffer = stream.GetBuffer();
+ stream.Position = 0;
+
+ using (ArrowStreamReader reader = new
ArrowStreamReader(stream, leaveOpen: true))
+ {
+ readBatch = reader.ReadNextRecordBatch();
+ }
+ }
+
+ System.Array.Clear(streamBuffer, 0, streamBuffer.Length);
+
+ using (readBatch)
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+ }
+
+ [Fact]
+ public async Task
ReadRecordBatchAsync_ExposedMemoryStream_BatchDoesNotAliasMutableStreamBuffer()
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100, createDictionaryArray: false);
+ RecordBatch readBatch;
+ byte[] streamBuffer;
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema, leaveOpen: true);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteEndAsync();
+
+ streamBuffer = stream.GetBuffer();
+ stream.Position = 0;
+
+ using (ArrowStreamReader reader = new
ArrowStreamReader(stream, leaveOpen: true))
+ {
+ readBatch = await reader.ReadNextRecordBatchAsync();
+ }
+ }
+
+ System.Array.Clear(streamBuffer, 0, streamBuffer.Length);
+
+ using (readBatch)
+ {
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+ }
+ }
+
[Theory]
[InlineData(true)]
[InlineData(false)]
@@ -243,11 +484,34 @@ namespace Apache.Arrow.Tests
/// <summary>
/// A stream class that only returns a part of the data at a time.
/// </summary>
- private class PartialReadStream : MemoryStream
+ private class PartialReadStream : Stream
{
+ private readonly MemoryStream _innerStream = new MemoryStream();
+
// by default return 20 bytes at a time
public int PartialReadLength { get; set; } = 20;
+ public override bool CanRead => _innerStream.CanRead;
+ public override bool CanSeek => _innerStream.CanSeek;
+ public override bool CanWrite => _innerStream.CanWrite;
+ public override long Length => _innerStream.Length;
+ public override long Position { get => _innerStream.Position; set
=> _innerStream.Position = value; }
+
+ public override void Flush() => _innerStream.Flush();
+ public override long Seek(long offset, SeekOrigin origin) =>
_innerStream.Seek(offset, origin);
+ public override void SetLength(long value) =>
_innerStream.SetLength(value);
+ public override void Write(byte[] buffer, int offset, int count)
=> _innerStream.Write(buffer, offset, count);
+
+ public override int Read(byte[] buffer, int offset, int count)
+ {
+ return _innerStream.Read(buffer, offset, Math.Min(count,
PartialReadLength));
+ }
+
+ public override Task<int> ReadAsync(byte[] buffer, int offset, int
count, CancellationToken cancellationToken = default)
+ {
+ return _innerStream.ReadAsync(buffer, offset, Math.Min(count,
PartialReadLength), cancellationToken);
+ }
+
#if NET5_0_OR_GREATER
public override int Read(Span<byte> destination)
{
@@ -256,7 +520,7 @@ namespace Apache.Arrow.Tests
destination = destination.Slice(0, PartialReadLength);
}
- return base.Read(destination);
+ return _innerStream.Read(destination);
}
public override ValueTask<int> ReadAsync(Memory<byte> destination,
CancellationToken cancellationToken = default)
@@ -266,53 +530,50 @@ namespace Apache.Arrow.Tests
destination = destination.Slice(0, PartialReadLength);
}
- return base.ReadAsync(destination, cancellationToken);
- }
-#else
- public override int Read(byte[] buffer, int offset, int length)
- {
- return base.Read(buffer, offset, Math.Min(length,
PartialReadLength));
- }
-
- public override Task<int> ReadAsync(byte[] buffer, int offset, int
length, CancellationToken cancellationToken = default)
- {
- return base.ReadAsync(buffer, offset, Math.Min(length,
PartialReadLength), cancellationToken);
+ return _innerStream.ReadAsync(destination, cancellationToken);
}
#endif
}
- [Fact]
- public unsafe void MalformedColumnNameLength()
+ private class RequiresCancelableReadStream : Stream
{
- const int FieldNameLengthOffset = 108;
- const int FakeFieldNameLength = 165535;
+ public bool SawCancelableToken { get; private set; }
- byte[] buffer;
- using (var stream = new MemoryStream())
+ public override bool CanRead => true;
+ public override bool CanSeek => false;
+ public override bool CanWrite => false;
+ public override long Length => throw new NotSupportedException();
+ public override long Position { get => throw new
NotSupportedException(); set => throw new NotSupportedException(); }
+
+ public override void Flush() { }
+ public override int Read(byte[] buffer, int offset, int count) =>
throw new NotSupportedException();
+ public override long Seek(long offset, SeekOrigin origin) => throw
new NotSupportedException();
+ public override void SetLength(long value) => throw new
NotSupportedException();
+ public override void Write(byte[] buffer, int offset, int count)
=> throw new NotSupportedException();
+
+#if NET5_0_OR_GREATER
+ public override ValueTask<int> ReadAsync(Memory<byte> buffer,
CancellationToken cancellationToken = default)
{
- Schema schema = new(
- [new Field("index", Int32Type.Default, nullable: false)],
- metadata: []);
- using (var writer = new ArrowStreamWriter(stream, schema,
leaveOpen: true))
+ SawCancelableToken = cancellationToken.CanBeCanceled;
+ if (!SawCancelableToken)
{
- writer.WriteStart();
- writer.WriteEnd();
+ throw new InvalidOperationException("Expected the caller's
cancellation token during schema read.");
}
- buffer = stream.ToArray();
- }
- Span<int> length = buffer.AsSpan().Slice(FieldNameLengthOffset,
sizeof(int)).CastTo<int>();
- Assert.Equal(5, length[0]);
- length[0] = FakeFieldNameLength;
-
- Assert.Throws<ArgumentOutOfRangeException>(() =>
+ throw new OperationCanceledException(cancellationToken);
+ }
+#else
+ public override Task<int> ReadAsync(byte[] buffer, int offset, int
count, CancellationToken cancellationToken)
{
- using (var stream = new MemoryStream(buffer))
- using (var reader = new ArrowStreamReader(stream))
+ SawCancelableToken = cancellationToken.CanBeCanceled;
+ if (!SawCancelableToken)
{
- reader.ReadNextRecordBatch();
+ throw new InvalidOperationException("Expected the caller's
cancellation token during schema read.");
}
- });
+
+ throw new OperationCanceledException(cancellationToken);
+ }
+#endif
}
[Fact]
@@ -468,34 +729,52 @@ namespace Apache.Arrow.Tests
private static void WriteInt64LittleEndian(byte[] buffer, int offset,
long value)
{
- System.Buffers.Binary.BinaryPrimitives.WriteInt64LittleEndian(
- buffer.AsSpan(offset), value);
+ BinaryPrimitives.WriteInt64LittleEndian(buffer.AsSpan(offset),
value);
}
[Fact]
- public async Task EmptyStreamNoSyncRead()
+ public unsafe void MalformedColumnNameLength()
{
- using (var stream = new EmptyAsyncOnlyStream())
+ const int FieldNameLengthOffset = 108;
+ const int FakeFieldNameLength = 165535;
+
+ byte[] buffer;
+ using (var stream = new MemoryStream())
{
- var reader = new ArrowStreamReader(stream);
- var schema = await reader.GetSchema();
- Assert.Null(schema);
+ Schema schema = new(
+ [new Field("index", Int32Type.Default, nullable: false)],
+ metadata: []);
+ using (var writer = new ArrowStreamWriter(stream, schema,
leaveOpen: true))
+ {
+ writer.WriteStart();
+ writer.WriteEnd();
+ }
+ buffer = stream.ToArray();
}
- }
- private static short ToInt16LittleEndian(byte[] buffer, int offset)
- {
- return
BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
- }
+ Span<int> length = buffer.AsSpan().Slice(FieldNameLengthOffset,
sizeof(int)).CastTo<int>();
+ Assert.Equal(5, length[0]);
+ length[0] = FakeFieldNameLength;
- private static int ToInt32LittleEndian(byte[] buffer, int offset)
- {
- return
BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+ Assert.Throws<ArgumentOutOfRangeException>(() =>
+ {
+ using (var stream = new MemoryStream(buffer))
+ using (var reader = new ArrowStreamReader(stream))
+ {
+ reader.ReadNextRecordBatch();
+ }
+ });
}
- private static long ToInt64LittleEndian(byte[] buffer, int offset)
+ [Fact]
+ public async Task EmptyStreamNoSyncRead()
{
- return
BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+ using (var stream = new EmptyAsyncOnlyStream())
+ {
+ var reader = new ArrowStreamReader(stream);
+ var schema = await reader.GetSchema();
+ Assert.Null(schema);
+ }
}
private class EmptyAsyncOnlyStream : Stream
@@ -512,5 +791,20 @@ namespace Apache.Arrow.Tests
public override void Write(byte[] buffer, int offset, int count)
=> throw new NotSupportedException();
public override Task<int> ReadAsync(byte[] buffer, int offset, int
count, CancellationToken cancellationToken) => Task.FromResult(0);
}
+
+ private static short ToInt16LittleEndian(byte[] buffer, int offset)
+ {
+ return
BinaryPrimitives.ReadInt16LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
+ private static int ToInt32LittleEndian(byte[] buffer, int offset)
+ {
+ return
BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan().Slice(offset));
+ }
+
+ private static long ToInt64LittleEndian(byte[] buffer, int offset)
+ {
+ return
BinaryPrimitives.ReadInt64LittleEndian(buffer.AsSpan().Slice(offset));
+ }
}
}