This is an automated email from the ASF dual-hosted git repository.
kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 80cb11a ARROW-4997: [C#] ArrowStreamReader doesn't consume whole
stream and doesn't implement sync read.
80cb11a is described below
commit 80cb11a57b9c7a9c275235eaa29e307b5d45fda5
Author: Eric Erhardt <[email protected]>
AuthorDate: Sat Mar 23 06:25:31 2019 +0900
ARROW-4997: [C#] ArrowStreamReader doesn't consume whole stream and doesn't
implement sync read.
When reading from a network stream, ArrowStreamReader doesn't check how
many bytes were read, and instead assumes the whole buffer was filled with a
single read call. This change reads repeatedly until the whole buffer is filled.
Also, when consuming a stream, only the async read method works; the
synchronous method throws NotImplementedException. This change implements the sync method.
Note: this implements a lot of the underlying functionality from #3925. The
difference here is that this change doesn't attempt to solve the perf issues
with allocating and copying memory multiple times. #3925 is specifically
solving that perf issue.
@stephentoub @pgovind @chutchinson
Author: Eric Erhardt <[email protected]>
Closes #4017 from eerhardt/ArrowReaderBlockingIssues and squashes the
following commits:
806a0d34 <Eric Erhardt> PR feedback
356a861c <Eric Erhardt> ArrowStreamReader doesn't consume whole stream and
doesn't implement sync read.
---
csharp/src/Apache.Arrow/Apache.Arrow.csproj | 4 +
csharp/src/Apache.Arrow/BitUtility.cs | 8 +
.../Apache.Arrow/Extensions/ArrayPoolExtensions.cs | 8 +-
.../Apache.Arrow/Extensions/StreamExtensions.cs | 70 +++++++
.../Extensions/StreamExtensions.netcoreapp2.1.cs | 29 +++
.../Extensions/StreamExtensions.netstandard.cs | 57 +++++-
.../Ipc/ArrowFileReaderImplementation.cs | 203 ++++++++++++++++-----
csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs | 9 +-
.../Ipc/ArrowMemoryReaderImplementation.cs | 9 +-
.../Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 7 +-
csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs | 4 +-
.../Ipc/ArrowStreamReaderImplementation.cs | 182 +++++++++++-------
csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs | 12 +-
.../Apache.Arrow.Tests/ArrowFileReaderTests.cs | 77 ++++++++
...StreamReaderTests.cs => ArrowReaderVerifier.cs} | 60 ++----
.../Apache.Arrow.Tests/ArrowStreamReaderTests.cs | 160 +++++++++-------
csharp/test/Apache.Arrow.Tests/app.config | 15 --
17 files changed, 651 insertions(+), 263 deletions(-)
diff --git a/csharp/src/Apache.Arrow/Apache.Arrow.csproj
b/csharp/src/Apache.Arrow/Apache.Arrow.csproj
index 9086c83..48e89d7 100644
--- a/csharp/src/Apache.Arrow/Apache.Arrow.csproj
+++ b/csharp/src/Apache.Arrow/Apache.Arrow.csproj
@@ -12,6 +12,7 @@
<PackageReference Include="System.Buffers" Version="4.5.0" />
<PackageReference Include="System.Memory" Version="4.5.2" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe"
Version="4.5.2" />
+ <PackageReference Include="System.Threading.Tasks.Extensions"
Version="4.5.2" />
<PackageReference Include="Microsoft.SourceLink.GitHub"
Version="1.0.0-beta2-18618-05" PrivateAssets="All" />
</ItemGroup>
@@ -31,6 +32,9 @@
</EmbeddedResource>
</ItemGroup>
+ <ItemGroup Condition="'$(TargetFramework)' == 'netstandard1.3'">
+ <Compile Remove="Extensions\StreamExtensions.netcoreapp2.1.cs" />
+ </ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netcoreapp2.1'">
<Compile Remove="Extensions\StreamExtensions.netstandard.cs" />
</ItemGroup>
diff --git a/csharp/src/Apache.Arrow/BitUtility.cs
b/csharp/src/Apache.Arrow/BitUtility.cs
index 3b4ee7a..fccdfe0 100644
--- a/csharp/src/Apache.Arrow/BitUtility.cs
+++ b/csharp/src/Apache.Arrow/BitUtility.cs
@@ -15,6 +15,8 @@
using System;
using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
namespace Apache.Arrow
{
@@ -116,5 +118,11 @@ namespace Apache.Arrow
return (n + (factor - 1)) & ~(factor - 1);
}
+ internal static int ReadInt32(ReadOnlyMemory<byte> value)
+ {
+ Debug.Assert(value.Length >= sizeof(int));
+
+ return Unsafe.ReadUnaligned<int>(ref
MemoryMarshal.GetReference(value.Span));
+ }
}
}
diff --git a/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
b/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
index e65a3ef..9dd9589 100644
--- a/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
+++ b/csharp/src/Apache.Arrow/Extensions/ArrayPoolExtensions.cs
@@ -23,14 +23,14 @@ namespace Apache.Arrow
internal static class ArrayPoolExtensions
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static void RentReturn(this ArrayPool<byte> pool, int length,
Action<byte[]> action)
+ public static void RentReturn(this ArrayPool<byte> pool, int length,
Action<Memory<byte>> action)
{
byte[] array = null;
try
{
array = pool.Rent(length);
- action(array);
+ action(array.AsMemory(0, length));
}
finally
{
@@ -42,14 +42,14 @@ namespace Apache.Arrow
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static Task RentReturnAsync(this ArrayPool<byte> pool, int
length, Func<byte[], Task> action)
+ public static ValueTask RentReturnAsync(this ArrayPool<byte> pool, int
length, Func<Memory<byte>, ValueTask> action)
{
byte[] array = null;
try
{
array = pool.Rent(length);
- return action(array);
+ return action(array.AsMemory(0, length));
}
finally
{
diff --git a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.cs
b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.cs
new file mode 100644
index 0000000..1767d23
--- /dev/null
+++ b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.cs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Apache.Arrow
+{
+ internal static partial class StreamExtensions
+ {
+ public static async ValueTask<int> ReadFullBufferAsync(this Stream
stream, Memory<byte> buffer, CancellationToken cancellationToken = default)
+ {
+ int totalBytesRead = 0;
+ do
+ {
+ int bytesRead =
+ await stream.ReadAsync(
+ buffer.Slice(totalBytesRead, buffer.Length -
totalBytesRead),
+ cancellationToken)
+ .ConfigureAwait(false);
+
+ if (bytesRead == 0)
+ {
+ // reached the end of the stream
+ return totalBytesRead;
+ }
+
+ totalBytesRead += bytesRead;
+ }
+ while (totalBytesRead < buffer.Length);
+
+ return totalBytesRead;
+ }
+
+ public static int ReadFullBuffer(this Stream stream, Memory<byte>
buffer)
+ {
+ int totalBytesRead = 0;
+ do
+ {
+ int bytesRead = stream.Read(
+ buffer.Slice(totalBytesRead, buffer.Length -
totalBytesRead));
+
+ if (bytesRead == 0)
+ {
+ // reached the end of the stream
+ return totalBytesRead;
+ }
+
+ totalBytesRead += bytesRead;
+ }
+ while (totalBytesRead < buffer.Length);
+
+ return totalBytesRead;
+ }
+ }
+}
diff --git
a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs
b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs
new file mode 100644
index 0000000..f51dc53
--- /dev/null
+++ b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netcoreapp2.1.cs
@@ -0,0 +1,29 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.IO;
+
+namespace Apache.Arrow
+{
+ // Helpers to read from Stream to Memory<byte> on netcoreapp
+ internal static partial class StreamExtensions
+ {
+ public static int Read(this Stream stream, Memory<byte> buffer)
+ {
+ return stream.Read(buffer.Span);
+ }
+ }
+}
diff --git a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs
b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs
index 94affa3..ce23bd1 100644
--- a/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs
+++ b/csharp/src/Apache.Arrow/Extensions/StreamExtensions.netstandard.cs
@@ -23,13 +23,62 @@ using System.Threading.Tasks;
namespace Apache.Arrow
{
// Helpers to write Memory<byte> to Stream on netstandard
- internal static class StreamExtensions
+ internal static partial class StreamExtensions
{
- public static Task WriteAsync(this Stream stream, ReadOnlyMemory<byte>
buffer, CancellationToken cancellationToken = default)
+ public static int Read(this Stream stream, Memory<byte> buffer)
{
if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte>
array))
{
- return stream.WriteAsync(array.Array, array.Offset,
array.Count, cancellationToken);
+ return stream.Read(array.Array, array.Offset, array.Count);
+ }
+ else
+ {
+ byte[] sharedBuffer =
ArrayPool<byte>.Shared.Rent(buffer.Length);
+ try
+ {
+ int result = stream.Read(sharedBuffer, 0, buffer.Length);
+ new Span<byte>(sharedBuffer, 0,
result).CopyTo(buffer.Span);
+ return result;
+ }
+ finally
+ {
+ ArrayPool<byte>.Shared.Return(sharedBuffer);
+ }
+ }
+ }
+
+ public static ValueTask<int> ReadAsync(this Stream stream,
Memory<byte> buffer, CancellationToken cancellationToken = default)
+ {
+ if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte>
array))
+ {
+ return new ValueTask<int>(stream.ReadAsync(array.Array,
array.Offset, array.Count, cancellationToken));
+ }
+ else
+ {
+ byte[] sharedBuffer =
ArrayPool<byte>.Shared.Rent(buffer.Length);
+ return FinishReadAsync(stream.ReadAsync(sharedBuffer, 0,
buffer.Length, cancellationToken), sharedBuffer, buffer);
+
+ async ValueTask<int> FinishReadAsync(Task<int> readTask,
byte[] localBuffer, Memory<byte> localDestination)
+ {
+ try
+ {
+ int result = await readTask.ConfigureAwait(false);
+ new Span<byte>(localBuffer, 0,
result).CopyTo(localDestination.Span);
+ return result;
+ }
+ finally
+ {
+ ArrayPool<byte>.Shared.Return(localBuffer);
+ }
+ }
+ }
+ }
+
+ public static ValueTask WriteAsync(this Stream stream,
ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
+ {
+ if (MemoryMarshal.TryGetArray(buffer, out ArraySegment<byte>
array))
+ {
+ return new ValueTask(stream.WriteAsync(array.Array,
array.Offset, array.Count, cancellationToken));
}
else
{
@@ -39,7 +88,7 @@ namespace Apache.Arrow
}
}
- private static async Task FinishWriteAsync(Task writeTask, byte[]
localBuffer)
+ private static async ValueTask FinishWriteAsync(Task writeTask, byte[]
localBuffer)
{
try
{
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
index 1a99a40..7a62085 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
@@ -13,9 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-using FlatBuffers;
using System;
-using System.Buffers.Binary;
using System.IO;
using System.Linq;
using System.Threading;
@@ -54,7 +52,7 @@ namespace Apache.Arrow.Ipc
return _footer.RecordBatchCount;
}
- protected override async Task ReadSchemaAsync()
+ protected override async ValueTask ReadSchemaAsync()
{
if (HasReadSchema)
{
@@ -63,43 +61,85 @@ namespace Apache.Arrow.Ipc
await ValidateFileAsync().ConfigureAwait(false);
- var bytesRead = 0;
- var footerLength = 0;
-
+ int footerLength = 0;
await Buffers.RentReturnAsync(4, async (buffer) =>
{
- BaseStream.Position = BaseStream.Length -
ArrowFileConstants.Magic.Length - 4;
-
- bytesRead = await BaseStream.ReadAsync(buffer, 0,
4).ConfigureAwait(false);
- footerLength = BinaryPrimitives.ReadInt32LittleEndian(buffer);
+ BaseStream.Position = GetFooterLengthPosition();
- if (bytesRead != 4) throw new InvalidDataException(
- $"Failed to read footer length. Read <{bytesRead}>,
expected 4.");
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
+ EnsureFullRead(buffer, bytesRead);
- if (footerLength <= 0) throw new InvalidDataException(
- $"Footer length has invalid size <{footerLength}>");
+ footerLength = ReadFooterLength(buffer);
}).ConfigureAwait(false);
await Buffers.RentReturnAsync(footerLength, async (buffer) =>
{
- _footerStartPostion = (int)BaseStream.Length - footerLength -
ArrowFileConstants.Magic.Length - 4;
+ _footerStartPostion = (int)GetFooterLengthPosition() -
footerLength;
BaseStream.Position = _footerStartPostion;
- bytesRead = await BaseStream.ReadAsync(buffer, 0,
footerLength).ConfigureAwait(false);
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
+ EnsureFullRead(buffer, bytesRead);
- if (bytesRead != footerLength)
- {
- throw new InvalidDataException(
- $"Failed to read footer. Read <{bytesRead}> bytes,
expected <{footerLength}>.");
- }
+ ReadSchema(buffer);
+ }).ConfigureAwait(false);
+ }
+
+ protected override void ReadSchema()
+ {
+ if (HasReadSchema)
+ {
+ return;
+ }
- // Deserialize the footer from the footer flatbuffer
+ ValidateFile();
- _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(new
ByteBuffer(buffer)));
+ int footerLength = 0;
+ Buffers.RentReturn(4, (buffer) =>
+ {
+ BaseStream.Position = GetFooterLengthPosition();
- Schema = _footer.Schema;
- }).ConfigureAwait(false);
+ int bytesRead = BaseStream.ReadFullBuffer(buffer);
+ EnsureFullRead(buffer, bytesRead);
+
+ footerLength = ReadFooterLength(buffer);
+ });
+
+ Buffers.RentReturn(footerLength, (buffer) =>
+ {
+ _footerStartPostion = (int)GetFooterLengthPosition() -
footerLength;
+
+ BaseStream.Position = _footerStartPostion;
+
+ int bytesRead = BaseStream.ReadFullBuffer(buffer);
+ EnsureFullRead(buffer, bytesRead);
+
+ ReadSchema(buffer);
+ });
+ }
+
+ private long GetFooterLengthPosition()
+ {
+ return BaseStream.Length - ArrowFileConstants.Magic.Length - 4;
+ }
+
+ private static int ReadFooterLength(Memory<byte> buffer)
+ {
+ int footerLength = BitUtility.ReadInt32(buffer);
+
+ if (footerLength <= 0)
+ throw new InvalidDataException(
+ $"Footer length has invalid size <{footerLength}>");
+
+ return footerLength;
+ }
+
+ private void ReadSchema(Memory<byte> buffer)
+ {
+ // Deserialize the footer from the footer flatbuffer
+ _footer = new
ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)));
+
+ Schema = _footer.Schema;
}
public async Task<RecordBatch> ReadRecordBatchAsync(int index,
CancellationToken cancellationToken)
@@ -118,7 +158,23 @@ namespace Apache.Arrow.Ipc
return await
ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
}
- public override async Task<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ public RecordBatch ReadRecordBatch(int index)
+ {
+ ReadSchema();
+
+ if (index >= _footer.RecordBatchCount)
+ {
+ throw new ArgumentOutOfRangeException(nameof(index));
+ }
+
+ var block = _footer.GetRecordBatchBlock(index);
+
+ BaseStream.Position = block.Offset;
+
+ return ReadRecordBatch();
+ }
+
+ public override async ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
await ReadSchemaAsync().ConfigureAwait(false);
@@ -133,10 +189,25 @@ namespace Apache.Arrow.Ipc
return result;
}
+ public override RecordBatch ReadNextRecordBatch()
+ {
+ ReadSchema();
+
+ if (_recordBatchIndex >= _footer.RecordBatchCount)
+ {
+ return null;
+ }
+
+ RecordBatch result = ReadRecordBatch(_recordBatchIndex);
+ _recordBatchIndex++;
+
+ return result;
+ }
+
/// <summary>
/// Check if file format is valid. If it's valid don't run the
validation again.
/// </summary>
- private async Task ValidateFileAsync()
+ private async ValueTask ValidateFileAsync()
{
if (IsFileValid)
{
@@ -148,7 +219,22 @@ namespace Apache.Arrow.Ipc
IsFileValid = true;
}
- private async Task ValidateMagicAsync()
+ /// <summary>
+ /// Check if file format is valid. If it's valid don't run the
validation again.
+ /// </summary>
+ private void ValidateFile()
+ {
+ if (IsFileValid)
+ {
+ return;
+ }
+
+ ValidateMagic();
+
+ IsFileValid = true;
+ }
+
+ private async ValueTask ValidateMagicAsync()
{
var startingPosition = BaseStream.Position;
var magicLength = ArrowFileConstants.Magic.Length;
@@ -158,32 +244,20 @@ namespace Apache.Arrow.Ipc
await Buffers.RentReturnAsync(magicLength, async (buffer) =>
{
// Seek to the beginning of the stream
-
BaseStream.Position = 0;
// Read beginning of stream
+ await BaseStream.ReadAsync(buffer).ConfigureAwait(false);
- await BaseStream.ReadAsync(buffer, 0,
magicLength).ConfigureAwait(false);
-
- if
(!ArrowFileConstants.Magic.SequenceEqual(buffer.Take(magicLength)))
- {
- throw new InvalidDataException(
- $"Invalid magic at offset
<{BaseStream.Position}>");
- }
+ VerifyMagic(buffer);
// Move stream position to magic-length bytes away from
the end of the stream
-
BaseStream.Position = BaseStream.Length - magicLength;
// Read the end of the stream
+ await BaseStream.ReadAsync(buffer).ConfigureAwait(false);
- await BaseStream.ReadAsync(buffer, 0,
magicLength).ConfigureAwait(false);
-
- if
(!ArrowFileConstants.Magic.SequenceEqual(buffer.Take(magicLength)))
- {
- throw new InvalidDataException(
- $"Invalid magic at offset
<{BaseStream.Position}>");
- }
+ VerifyMagic(buffer);
}).ConfigureAwait(false);
}
finally
@@ -191,5 +265,46 @@ namespace Apache.Arrow.Ipc
BaseStream.Position = startingPosition;
}
}
+
+ private void ValidateMagic()
+ {
+ var startingPosition = BaseStream.Position;
+ var magicLength = ArrowFileConstants.Magic.Length;
+
+ try
+ {
+ Buffers.RentReturn(magicLength, buffer =>
+ {
+ // Seek to the beginning of the stream
+ BaseStream.Position = 0;
+
+ // Read beginning of stream
+ BaseStream.Read(buffer);
+
+ VerifyMagic(buffer);
+
+ // Move stream position to magic-length bytes away from
the end of the stream
+ BaseStream.Position = BaseStream.Length - magicLength;
+
+ // Read the end of the stream
+ BaseStream.Read(buffer);
+
+ VerifyMagic(buffer);
+ });
+ }
+ finally
+ {
+ BaseStream.Position = startingPosition;
+ }
+ }
+
+ private void VerifyMagic(Memory<byte> buffer)
+ {
+ if (!ArrowFileConstants.Magic.AsSpan().SequenceEqual(buffer.Span))
+ {
+ throw new InvalidDataException(
+ $"Invalid magic at offset <{BaseStream.Position}>");
+ }
+ }
}
}
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
index 459cb51..b74bcc4 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
@@ -84,7 +84,7 @@ namespace Apache.Arrow.Ipc
await
BaseStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
- private async Task WriteHeaderAsync(CancellationToken
cancellationToken)
+ private async ValueTask WriteHeaderAsync(CancellationToken
cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
@@ -95,7 +95,7 @@ namespace Apache.Arrow.Ipc
.ConfigureAwait(false);
}
- private async Task WriteFooterAsync(Schema schema, CancellationToken
cancellationToken)
+ private async ValueTask WriteFooterAsync(Schema schema,
CancellationToken cancellationToken)
{
Builder.Clear();
@@ -141,10 +141,10 @@ namespace Apache.Arrow.Ipc
await Buffers.RentReturnAsync(4, async (buffer) =>
{
- BinaryPrimitives.WriteInt32LittleEndian(buffer,
+ BinaryPrimitives.WriteInt32LittleEndian(buffer.Span,
Convert.ToInt32(BaseStream.Position - offset));
- await BaseStream.WriteAsync(buffer, 0, 4,
cancellationToken).ConfigureAwait(false);
+ await BaseStream.WriteAsync(buffer,
cancellationToken).ConfigureAwait(false);
}).ConfigureAwait(false);
// Write magic
@@ -159,6 +159,5 @@ namespace Apache.Arrow.Ipc
return BaseStream.WriteAsync(
ArrowFileConstants.Magic, 0, ArrowFileConstants.Magic.Length);
}
-
}
}
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index 3656c84..df8b809 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -32,10 +32,10 @@ namespace Apache.Arrow.Ipc
_buffer = buffer;
}
- public override Task<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ public override ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
- return Task.FromResult(ReadNextRecordBatch());
+ return new ValueTask<RecordBatch>(ReadNextRecordBatch());
}
public override RecordBatch ReadNextRecordBatch()
@@ -74,11 +74,6 @@ namespace Apache.Arrow.Ipc
return new ArrowBuffer(data);
}
- private static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte> buffer)
- {
- return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer),
0);
- }
-
private void ReadSchema()
{
if (HasReadSchema)
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
index aa4f748..e5e9802 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
@@ -39,7 +39,7 @@ namespace Apache.Arrow.Ipc
{
}
- public abstract Task<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken);
+ public abstract ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken);
public abstract RecordBatch ReadNextRecordBatch();
protected abstract ArrowBuffer CreateArrowBuffer(ReadOnlyMemory<byte>
data);
@@ -104,6 +104,11 @@ namespace Apache.Arrow.Ipc
return null;
}
+ internal static ByteBuffer CreateByteBuffer(ReadOnlyMemory<byte>
buffer)
+ {
+ return new ByteBuffer(new ReadOnlyMemoryBufferAllocator(buffer),
0);
+ }
+
private List<IArrowArray> BuildArrays(
Schema schema,
ByteBuffer messageBuffer,
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
index a399056..0923968 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
@@ -66,9 +66,9 @@ namespace Apache.Arrow.Ipc
}
}
- public Task<RecordBatch> ReadNextRecordBatchAsync(CancellationToken
cancellationToken = default)
+ public async Task<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
- return _implementation.ReadNextRecordBatchAsync(cancellationToken);
+ return await
_implementation.ReadNextRecordBatchAsync(cancellationToken).ConfigureAwait(false);
}
public RecordBatch ReadNextRecordBatch()
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 36e2e57..6ca0518 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -42,7 +42,7 @@ namespace Apache.Arrow.Ipc
}
}
- public override async Task<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
+ public override async ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
// TODO: Loop until a record batch is read.
cancellationToken.ThrowIfCancellationRequested();
@@ -51,107 +51,150 @@ namespace Apache.Arrow.Ipc
public override RecordBatch ReadNextRecordBatch()
{
- throw new NotImplementedException();
+ return ReadRecordBatch();
}
- protected async Task<RecordBatch>
ReadRecordBatchAsync(CancellationToken cancellationToken = default)
+ protected async ValueTask<RecordBatch>
ReadRecordBatchAsync(CancellationToken cancellationToken = default)
{
await ReadSchemaAsync().ConfigureAwait(false);
- var bytesRead = 0;
-
- byte[] lengthBuffer = null;
- byte[] messageBuff = null;
- byte[] bodyBuff = null;
-
- try
+ int messageLength = 0;
+ await Buffers.RentReturnAsync(4, async (lengthBuffer) =>
{
// Get Length of record batch for message header.
-
- lengthBuffer = Buffers.Rent(4);
- bytesRead += await BaseStream.ReadAsync(lengthBuffer, 0, 4,
cancellationToken)
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken)
.ConfigureAwait(false);
- if (bytesRead != 4)
+ if (bytesRead == 4)
{
- //reached the end
- return null;
+ messageLength = BitUtility.ReadInt32(lengthBuffer);
}
+ }).ConfigureAwait(false);
- var messageLength = BitConverter.ToInt32(lengthBuffer, 0);
+ if (messageLength == 0)
+ {
+ // reached end
+ return null;
+ }
- if (messageLength == 0)
+ RecordBatch result = null;
+ await Buffers.RentReturnAsync(messageLength, async (messageBuff) =>
+ {
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(messageBuff, cancellationToken)
+ .ConfigureAwait(false);
+ EnsureFullRead(messageBuff, bytesRead);
+
+ Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));
+
+ await Buffers.RentReturnAsync((int)message.BodyLength, async
(bodyBuff) =>
{
- //reached the end
- return null;
- }
+ int bodyBytesRead = await
BaseStream.ReadFullBufferAsync(bodyBuff, cancellationToken)
+ .ConfigureAwait(false);
+ EnsureFullRead(bodyBuff, bodyBytesRead);
- messageBuff = Buffers.Rent(messageLength);
- bytesRead += await BaseStream.ReadAsync(messageBuff, 0,
messageLength, cancellationToken)
- .ConfigureAwait(false);
- var message = Flatbuf.Message.GetRootAsMessage(new
FlatBuffers.ByteBuffer(messageBuff));
+ FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff);
+ result = CreateArrowObjectFromMessage(message, bodybb);
+ }).ConfigureAwait(false);
+ }).ConfigureAwait(false);
- bodyBuff = Buffers.Rent((int)message.BodyLength);
- var bodybb = new FlatBuffers.ByteBuffer(bodyBuff);
- bytesRead += await BaseStream.ReadAsync(bodyBuff, 0,
(int)message.BodyLength, cancellationToken)
- .ConfigureAwait(false);
+ return result;
+ }
- return CreateArrowObjectFromMessage(message, bodybb);
- }
- finally
+ protected RecordBatch ReadRecordBatch()
+ {
+ ReadSchema();
+
+ int messageLength = 0;
+ Buffers.RentReturn(4, lengthBuffer =>
{
- if (lengthBuffer != null)
- {
- Buffers.Return(lengthBuffer);
- }
+ int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer);
- if (messageBuff != null)
+ if (bytesRead == 4)
{
- Buffers.Return(messageBuff);
+ messageLength = BitUtility.ReadInt32(lengthBuffer);
}
+ });
- if (bodyBuff != null)
- {
- Buffers.Return(bodyBuff);
- }
+ if (messageLength == 0)
+ {
+ // reached end
+ return null;
}
+
+ RecordBatch result = null;
+ Buffers.RentReturn(messageLength, messageBuff =>
+ {
+ int bytesRead = BaseStream.ReadFullBuffer(messageBuff);
+ EnsureFullRead(messageBuff, bytesRead);
+
+ Flatbuf.Message message =
Flatbuf.Message.GetRootAsMessage(CreateByteBuffer(messageBuff));
+
+ Buffers.RentReturn((int)message.BodyLength, bodyBuff =>
+ {
+ int bodyBytesRead = BaseStream.ReadFullBuffer(bodyBuff);
+ EnsureFullRead(bodyBuff, bodyBytesRead);
+
+ FlatBuffers.ByteBuffer bodybb = CreateByteBuffer(bodyBuff);
+ result = CreateArrowObjectFromMessage(message, bodybb);
+ });
+ });
+
+ return result;
}
- protected virtual async Task ReadSchemaAsync()
+ protected virtual async ValueTask ReadSchemaAsync()
{
if (HasReadSchema)
{
return;
}
- byte[] buff = null;
-
- try
+ // Figure out length of schema
+ int schemaMessageLength = 0;
+ await Buffers.RentReturnAsync(4, async (lengthBuffer) =>
{
- // Figure out length of schema
-
- buff = Buffers.Rent(4);
- await BaseStream.ReadAsync(buff, 0, 4).ConfigureAwait(false);
- var schemaMessageLength = BitConverter.ToInt32(buff, 0);
- Buffers.Return(buff);
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(lengthBuffer).ConfigureAwait(false);
+ EnsureFullRead(lengthBuffer, bytesRead);
- // Allocate byte array for schema flat buffer
-
- buff = Buffers.Rent(schemaMessageLength);
- var schemabb = new FlatBuffers.ByteBuffer(buff);
+ schemaMessageLength = BitUtility.ReadInt32(lengthBuffer);
+ }).ConfigureAwait(false);
+ await Buffers.RentReturnAsync(schemaMessageLength, async (buff) =>
+ {
// Read in schema
+ int bytesRead = await
BaseStream.ReadFullBufferAsync(buff).ConfigureAwait(false);
+ EnsureFullRead(buff, bytesRead);
- await BaseStream.ReadAsync(buff, 0,
schemaMessageLength).ConfigureAwait(false);
+ var schemabb = CreateByteBuffer(buff);
Schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb));
- }
- finally
+ }).ConfigureAwait(false);
+ }
+
+ protected virtual void ReadSchema()
+ {
+ if (HasReadSchema)
{
- if (buff != null)
- {
- Buffers.Return(buff);
- }
+ return;
}
+
+ // Figure out length of schema
+ int schemaMessageLength = 0;
+ Buffers.RentReturn(4, lengthBuffer =>
+ {
+ int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer);
+ EnsureFullRead(lengthBuffer, bytesRead);
+
+ schemaMessageLength = BitUtility.ReadInt32(lengthBuffer);
+ });
+
+ Buffers.RentReturn(schemaMessageLength, buff =>
+ {
+ int bytesRead = BaseStream.ReadFullBuffer(buff);
+ EnsureFullRead(buff, bytesRead);
+
+ var schemabb = CreateByteBuffer(buff);
+ Schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb));
+ });
}
protected override ArrowBuffer CreateArrowBuffer(ReadOnlyMemory<byte>
data)
@@ -162,5 +205,18 @@ namespace Apache.Arrow.Ipc
.Append(data.Span)
.Build();
}
+
+ /// <summary>
+ /// Ensures the number of bytes read matches the buffer length
+ /// and throws an exception if it doesn't. This ensures we have read
+ /// a full buffer from the stream.
+ /// </summary>
+ internal static void EnsureFullRead(Memory<byte> buffer, int bytesRead)
+ {
+ if (bytesRead != buffer.Length)
+ {
+ throw new InvalidOperationException("Unexpectedly reached the
end of the stream before a full buffer was read.");
+ }
+ }
}
}
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
index b87b71b..c1a6646 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
@@ -320,7 +320,7 @@ namespace Apache.Arrow.Ipc
}
- private async Task<Offset<Flatbuf.Schema>> WriteSchemaAsync(Schema
schema, CancellationToken cancellationToken)
+ private async ValueTask<Offset<Flatbuf.Schema>>
WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)
{
Builder.Clear();
@@ -336,7 +336,7 @@ namespace Apache.Arrow.Ipc
return schemaOffset;
}
- private async Task WriteMessageAsync<T>(
+ private async ValueTask WriteMessageAsync<T>(
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int
bodyLength,
CancellationToken cancellationToken)
where T: struct
@@ -350,18 +350,18 @@ namespace Apache.Arrow.Ipc
var messageData =
Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position,
Builder.Offset);
var messagePaddingLength = CalculatePadding(messageData.Length);
- await Buffers.RentReturnAsync(4, (buffer) =>
+ await Buffers.RentReturnAsync(4, async (buffer) =>
{
var metadataSize = messageData.Length + messagePaddingLength;
- BinaryPrimitives.WriteInt32LittleEndian(buffer, metadataSize);
- return BaseStream.WriteAsync(buffer, 0, 4, cancellationToken);
+ BinaryPrimitives.WriteInt32LittleEndian(buffer.Span,
metadataSize);
+ await BaseStream.WriteAsync(buffer,
cancellationToken).ConfigureAwait(false);
}).ConfigureAwait(false);
await BaseStream.WriteAsync(messageData,
cancellationToken).ConfigureAwait(false);
await
WritePaddingAsync(messagePaddingLength).ConfigureAwait(false);
}
- private protected async Task WriteFlatBufferAsync(CancellationToken
cancellationToken = default)
+ private protected async ValueTask
WriteFlatBufferAsync(CancellationToken cancellationToken = default)
{
var segment =
Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position,
Builder.Offset);
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
index b2b769d..b756faa 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowFileReaderTests.cs
@@ -16,6 +16,7 @@
using Apache.Arrow.Ipc;
using System;
using System.IO;
+using System.Threading.Tasks;
using Xunit;
namespace Apache.Arrow.Tests
@@ -45,5 +46,81 @@ namespace Apache.Arrow.Tests
new ArrowFileReader(stream, leaveOpen: true).Dispose();
Assert.Equal(0, stream.Position);
}
+
+ [Fact]
+ public async Task TestReadNextRecordBatch()
+ {
+ await TestReadRecordBatchHelper((reader, originalBatch) =>
+ {
+ ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+ return Task.CompletedTask;
+ });
+ }
+
+ [Fact]
+ public async Task TestReadNextRecordBatchAsync()
+ {
+ await
TestReadRecordBatchHelper(ArrowReaderVerifier.VerifyReaderAsync);
+ }
+
+ [Fact]
+ public async Task TestReadRecordBatchAsync()
+ {
+ await TestReadRecordBatchHelper(async (reader, originalBatch) =>
+ {
+ RecordBatch readBatch = await reader.ReadRecordBatchAsync(0);
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch);
+
+ // You should be able to read the same record batch again
+ RecordBatch readBatch2 = await reader.ReadRecordBatchAsync(0);
+ ArrowReaderVerifier.CompareBatches(originalBatch, readBatch2);
+ });
+ }
+
+ private static async Task TestReadRecordBatchHelper(
+ Func<ArrowFileReader, RecordBatch, Task> verificationFunc)
+ {
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100);
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowFileWriter writer = new ArrowFileWriter(stream,
originalBatch.Schema);
+ await writer.WriteRecordBatchAsync(originalBatch);
+ await writer.WriteFooterAsync();
+ stream.Position = 0;
+
+ ArrowFileReader reader = new ArrowFileReader(stream);
+ await verificationFunc(reader, originalBatch);
+ }
+ }
+
+ [Fact]
+ public async Task TestReadMultipleRecordBatchAsync()
+ {
+ RecordBatch originalBatch1 =
TestData.CreateSampleRecordBatch(length: 100);
+ RecordBatch originalBatch2 =
TestData.CreateSampleRecordBatch(length: 50);
+
+ using (MemoryStream stream = new MemoryStream())
+ {
+ ArrowFileWriter writer = new ArrowFileWriter(stream,
originalBatch1.Schema);
+ await writer.WriteRecordBatchAsync(originalBatch1);
+ await writer.WriteRecordBatchAsync(originalBatch2);
+ await writer.WriteFooterAsync();
+ stream.Position = 0;
+
+ // the recordbatches by index are in reverse order - back to
front.
+ // TODO: is this a bug??
+ ArrowFileReader reader = new ArrowFileReader(stream);
+ RecordBatch readBatch1 = await reader.ReadRecordBatchAsync(0);
+ ArrowReaderVerifier.CompareBatches(originalBatch2, readBatch1);
+
+ RecordBatch readBatch2 = await reader.ReadRecordBatchAsync(1);
+ ArrowReaderVerifier.CompareBatches(originalBatch1, readBatch2);
+
+ // now read the first again, for random access
+ RecordBatch readBatch3 = await reader.ReadRecordBatchAsync(0);
+ ArrowReaderVerifier.CompareBatches(originalBatch2, readBatch3);
+ }
+ }
}
}
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
similarity index 74%
copy from csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
copy to csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
index 3a1dae6..d7f0c26 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
@@ -15,53 +15,16 @@
using Apache.Arrow.Ipc;
using System;
-using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
namespace Apache.Arrow.Tests
{
- public class ArrowStreamReaderTests
+ public static class ArrowReaderVerifier
{
- [Fact]
- public void Ctor_LeaveOpenDefault_StreamClosedOnDispose()
+ public static void VerifyReader(ArrowStreamReader reader, RecordBatch
originalBatch)
{
- var stream = new MemoryStream();
- new ArrowStreamReader(stream).Dispose();
- Assert.Throws<ObjectDisposedException>(() => stream.Position);
- }
-
- [Fact]
- public void Ctor_LeaveOpenFalse_StreamClosedOnDispose()
- {
- var stream = new MemoryStream();
- new ArrowStreamReader(stream, leaveOpen: false).Dispose();
- Assert.Throws<ObjectDisposedException>(() => stream.Position);
- }
-
- [Fact]
- public void Ctor_LeaveOpenTrue_StreamValidOnDispose()
- {
- var stream = new MemoryStream();
- new ArrowStreamReader(stream, leaveOpen: true).Dispose();
- Assert.Equal(0, stream.Position);
- }
-
- [Fact]
- public async Task ReadRecordBatch()
- {
- RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100);
-
- byte[] buffer;
- using (MemoryStream stream = new MemoryStream())
- {
- ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema);
- await writer.WriteRecordBatchAsync(originalBatch);
- buffer = stream.GetBuffer();
- }
-
- ArrowStreamReader reader = new ArrowStreamReader(buffer);
RecordBatch readBatch = reader.ReadNextRecordBatch();
CompareBatches(originalBatch, readBatch);
@@ -70,9 +33,19 @@ namespace Apache.Arrow.Tests
Assert.Null(reader.ReadNextRecordBatch());
}
- private void CompareBatches(RecordBatch expectedBatch, RecordBatch
actualBatch)
+ public static async Task VerifyReaderAsync(ArrowStreamReader reader,
RecordBatch originalBatch)
+ {
+ RecordBatch readBatch = await reader.ReadNextRecordBatchAsync();
+ CompareBatches(originalBatch, readBatch);
+
+ // There should only be one batch - calling
ReadNextRecordBatchAsync again should return null.
+ Assert.Null(await reader.ReadNextRecordBatchAsync());
+ Assert.Null(await reader.ReadNextRecordBatchAsync());
+ }
+
+ public static void CompareBatches(RecordBatch expectedBatch,
RecordBatch actualBatch)
{
- CompareSchemas(expectedBatch.Schema, actualBatch.Schema);
+ Assert.True(SchemaComparer.Equals(expectedBatch.Schema,
actualBatch.Schema));
Assert.Equal(expectedBatch.Length, actualBatch.Length);
Assert.Equal(expectedBatch.ColumnCount, actualBatch.ColumnCount);
@@ -85,11 +58,6 @@ namespace Apache.Arrow.Tests
}
}
- private void CompareSchemas(Schema expectedSchema, Schema actualSchema)
- {
- Assert.True(SchemaComparer.Equals(expectedSchema, actualSchema));
- }
-
private class ArrayComparer :
IArrowArrayVisitor<Int8Array>,
IArrowArrayVisitor<Int16Array>,
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index 3a1dae6..0a2670c 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -16,7 +16,7 @@
using Apache.Arrow.Ipc;
using System;
using System.IO;
-using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
using Xunit;
@@ -49,7 +49,22 @@ namespace Apache.Arrow.Tests
}
[Fact]
- public async Task ReadRecordBatch()
+ public async Task ReadRecordBatch_Memory()
+ {
+ await TestReaderFromMemory((reader, originalBatch) =>
+ {
+ ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+ return Task.CompletedTask;
+ });
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_Memory()
+ {
+ await TestReaderFromMemory(ArrowReaderVerifier.VerifyReaderAsync);
+ }
+
+ private static async Task TestReaderFromMemory(Func<ArrowStreamReader,
RecordBatch, Task> verificationFunc)
{
RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100);
@@ -62,92 +77,105 @@ namespace Apache.Arrow.Tests
}
ArrowStreamReader reader = new ArrowStreamReader(buffer);
- RecordBatch readBatch = reader.ReadNextRecordBatch();
- CompareBatches(originalBatch, readBatch);
+ await verificationFunc(reader, originalBatch);
+ }
+
+ [Fact]
+ public async Task ReadRecordBatch_Stream()
+ {
+ await TestReaderFromStream((reader, originalBatch) =>
+ {
+ ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+ return Task.CompletedTask;
+ });
+ }
- // There should only be one batch - calling ReadNextRecordBatch
again should return null.
- Assert.Null(reader.ReadNextRecordBatch());
- Assert.Null(reader.ReadNextRecordBatch());
+ [Fact]
+ public async Task ReadRecordBatchAsync_Stream()
+ {
+ await TestReaderFromStream(ArrowReaderVerifier.VerifyReaderAsync);
}
- private void CompareBatches(RecordBatch expectedBatch, RecordBatch
actualBatch)
+ private static async Task TestReaderFromStream(Func<ArrowStreamReader,
RecordBatch, Task> verificationFunc)
{
- CompareSchemas(expectedBatch.Schema, actualBatch.Schema);
- Assert.Equal(expectedBatch.Length, actualBatch.Length);
- Assert.Equal(expectedBatch.ColumnCount, actualBatch.ColumnCount);
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100);
- for (int i = 0; i < expectedBatch.ColumnCount; i++)
+ using (MemoryStream stream = new MemoryStream())
{
- IArrowArray expectedArray = expectedBatch.Arrays.ElementAt(i);
- IArrowArray actualArray = actualBatch.Arrays.ElementAt(i);
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema);
+ await writer.WriteRecordBatchAsync(originalBatch);
- actualArray.Accept(new ArrayComparer(expectedArray));
+ stream.Position = 0;
+
+ ArrowStreamReader reader = new ArrowStreamReader(stream);
+ await verificationFunc(reader, originalBatch);
}
}
- private void CompareSchemas(Schema expectedSchema, Schema actualSchema)
+ [Fact]
+ public async Task ReadRecordBatch_PartialReadStream()
+ {
+ await TestReaderFromPartialReadStream((reader, originalBatch) =>
+ {
+ ArrowReaderVerifier.VerifyReader(reader, originalBatch);
+ return Task.CompletedTask;
+ });
+ }
+
+ [Fact]
+ public async Task ReadRecordBatchAsync_PartialReadStream()
{
- Assert.True(SchemaComparer.Equals(expectedSchema, actualSchema));
+ await
TestReaderFromPartialReadStream(ArrowReaderVerifier.VerifyReaderAsync);
}
- private class ArrayComparer :
- IArrowArrayVisitor<Int8Array>,
- IArrowArrayVisitor<Int16Array>,
- IArrowArrayVisitor<Int32Array>,
- IArrowArrayVisitor<Int64Array>,
- IArrowArrayVisitor<UInt8Array>,
- IArrowArrayVisitor<UInt16Array>,
- IArrowArrayVisitor<UInt32Array>,
- IArrowArrayVisitor<UInt64Array>,
- IArrowArrayVisitor<FloatArray>,
- IArrowArrayVisitor<DoubleArray>,
- IArrowArrayVisitor<BooleanArray>,
- IArrowArrayVisitor<TimestampArray>,
- IArrowArrayVisitor<Date32Array>,
- IArrowArrayVisitor<Date64Array>,
- IArrowArrayVisitor<ListArray>,
- IArrowArrayVisitor<StringArray>,
- IArrowArrayVisitor<BinaryArray>
+ /// <summary>
+ /// Verifies that the stream reader reads multiple times when a stream
+ /// only returns a subset of the data from each Read.
+ /// </summary>
+ private static async Task
TestReaderFromPartialReadStream(Func<ArrowStreamReader, RecordBatch, Task>
verificationFunc)
{
- private readonly IArrowArray _expectedArray;
+ RecordBatch originalBatch =
TestData.CreateSampleRecordBatch(length: 100);
- public ArrayComparer(IArrowArray expectedArray)
+ using (PartialReadStream stream = new PartialReadStream())
{
- _expectedArray = expectedArray;
+ ArrowStreamWriter writer = new ArrowStreamWriter(stream,
originalBatch.Schema);
+ await writer.WriteRecordBatchAsync(originalBatch);
+
+ stream.Position = 0;
+
+ ArrowStreamReader reader = new ArrowStreamReader(stream);
+ await verificationFunc(reader, originalBatch);
}
+ }
+
+ /// <summary>
+ /// A stream class that only returns a part of the data at a time.
+ /// </summary>
+ private class PartialReadStream : MemoryStream
+ {
+ // by default return 20 bytes at a time
+ public int PartialReadLength { get; set; } = 20;
- public void Visit(Int8Array array) => CompareArrays(array);
- public void Visit(Int16Array array) => CompareArrays(array);
- public void Visit(Int32Array array) => CompareArrays(array);
- public void Visit(Int64Array array) => CompareArrays(array);
- public void Visit(UInt8Array array) => CompareArrays(array);
- public void Visit(UInt16Array array) => CompareArrays(array);
- public void Visit(UInt32Array array) => CompareArrays(array);
- public void Visit(UInt64Array array) => CompareArrays(array);
- public void Visit(FloatArray array) => CompareArrays(array);
- public void Visit(DoubleArray array) => CompareArrays(array);
- public void Visit(BooleanArray array) => CompareArrays(array);
- public void Visit(TimestampArray array) => CompareArrays(array);
- public void Visit(Date32Array array) => CompareArrays(array);
- public void Visit(Date64Array array) => CompareArrays(array);
- public void Visit(ListArray array) => throw new
NotImplementedException();
- public void Visit(StringArray array) => throw new
NotImplementedException();
- public void Visit(BinaryArray array) => throw new
NotImplementedException();
- public void Visit(IArrowArray array) => throw new
NotImplementedException();
-
- private void CompareArrays<T>(PrimitiveArray<T> actualArray)
- where T : struct, IEquatable<T>
+ public override int Read(Span<byte> destination)
{
- Assert.IsAssignableFrom<PrimitiveArray<T>>(_expectedArray);
- PrimitiveArray<T> expectedArray =
(PrimitiveArray<T>)_expectedArray;
+ if (destination.Length > PartialReadLength)
+ {
+ destination = destination.Slice(0, PartialReadLength);
+ }
- Assert.Equal(expectedArray.Length, actualArray.Length);
- Assert.Equal(expectedArray.NullCount, actualArray.NullCount);
- Assert.Equal(expectedArray.Offset, actualArray.Offset);
+ return base.Read(destination);
+ }
+
+ public override ValueTask<int> ReadAsync(Memory<byte> destination,
CancellationToken cancellationToken = default)
+ {
+ if (destination.Length > PartialReadLength)
+ {
+ destination = destination.Slice(0, PartialReadLength);
+ }
-
Assert.True(expectedArray.NullBitmapBuffer.Span.SequenceEqual(actualArray.NullBitmapBuffer.Span));
- Assert.True(expectedArray.Values.Slice(0,
expectedArray.Length).SequenceEqual(actualArray.Values.Slice(0,
actualArray.Length)));
+ return base.ReadAsync(destination, cancellationToken);
}
}
}
}
+
diff --git a/csharp/test/Apache.Arrow.Tests/app.config
b/csharp/test/Apache.Arrow.Tests/app.config
deleted file mode 100644
index b90af84..0000000
--- a/csharp/test/Apache.Arrow.Tests/app.config
+++ /dev/null
@@ -1,15 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-<configuration>
- <runtime>
- <assemblyBinding xmlns="urn:schemas-microsoft-com:asm.v1">
- <dependentAssembly>
- <assemblyIdentity name="System.Numerics.Vectors"
publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
- <bindingRedirect oldVersion="0.0.0.0-4.1.4.0" newVersion="4.1.4.0" />
- </dependentAssembly>
- <dependentAssembly>
- <assemblyIdentity name="System.Buffers"
publicKeyToken="cc7b13ffcd2ddd51" culture="neutral" />
- <bindingRedirect oldVersion="0.0.0.0-4.0.3.0" newVersion="4.0.3.0" />
- </dependentAssembly>
- </assemblyBinding>
- </runtime>
-</configuration>
\ No newline at end of file