This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new cc771a0133 GH-40634: [C#] ArrowStreamReader should not be null (#40765)
cc771a0133 is described below
commit cc771a013362248269b75e054c2fed9c3d0f352a
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Mon Mar 25 07:59:38 2024 -0700
GH-40634: [C#] ArrowStreamReader should not be null (#40765)
### What changes are included in this PR?
A small refactoring of how the IPC reader implementation classes read the
schema, so that the schema can be retrieved asynchronously through
ArrowStreamReader and so that ArrowStreamReader.Schema no longer returns
null when no record batches have been read yet.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
A new method, ArrowStreamReader.GetSchema, has been added so that the schema
can be retrieved asynchronously.
Closes #40634
* GitHub Issue: #40634
Authored-by: Curt Hagenlocher <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
---
.../FlightRecordBatchStreamReader.cs | 4 ++--
.../Internal/RecordBatchReaderImplementation.cs | 27 ++++++++++++++++------
.../Ipc/ArrowFileReaderImplementation.cs | 6 ++---
.../Ipc/ArrowMemoryReaderImplementation.cs | 11 +++++++--
.../Apache.Arrow/Ipc/ArrowReaderImplementation.cs | 19 +++++++++++++--
csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs | 12 ++++++++++
.../Ipc/ArrowStreamReaderImplementation.cs | 8 +++----
.../test/Apache.Arrow.Tests/ArrowReaderVerifier.cs | 3 +++
.../Apache.Arrow.Tests/ArrowStreamReaderTests.cs | 2 ++
9 files changed, 72 insertions(+), 20 deletions(-)
diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs
b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs
index d21fb25f5c..7400ec15e5 100644
--- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs
+++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamReader.cs
@@ -45,12 +45,12 @@ namespace Apache.Arrow.Flight
_arrowReaderImplementation = new
RecordBatchReaderImplementation(flightDataStream);
}
- public ValueTask<Schema> Schema =>
_arrowReaderImplementation.ReadSchema();
+ public ValueTask<Schema> Schema =>
_arrowReaderImplementation.GetSchemaAsync();
internal ValueTask<FlightDescriptor> GetFlightDescriptor()
{
return _arrowReaderImplementation.ReadFlightDescriptor();
- }
+ }
/// <summary>
/// Get the application metadata from the latest received record batch
diff --git
a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
index be844ea58e..99876bf769 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/RecordBatchReaderImplementation.cs
@@ -48,19 +48,33 @@ namespace Apache.Arrow.Flight.Internal
{
if (!HasReadSchema)
{
- await ReadSchema().ConfigureAwait(false);
+ await
ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false);
}
return _flightDescriptor;
}
- public async ValueTask<Schema> ReadSchema()
+ /// <summary>
+ /// Returns the schema, first reading the schema message from the flight
+ /// data stream if it has not been read yet. Cancellation is not observed:
+ /// ReadSchemaAsync is invoked with CancellationToken.None.
+ /// </summary>
+ public async ValueTask<Schema> GetSchemaAsync()
+ {
+ if (!HasReadSchema)
+ {
+ await
ReadSchemaAsync(CancellationToken.None).ConfigureAwait(false);
+ }
+ return _schema;
+ }
+
+ /// <summary>
+ /// Synchronously reads the schema by blocking on ReadSchemaAsync.
+ /// Prefer GetSchemaAsync when an asynchronous caller is available.
+ /// </summary>
+ public override void ReadSchema()
+ {
+ // GetAwaiter().GetResult() rethrows the original exception; Wait() would wrap it in an AggregateException.
+ ReadSchemaAsync(CancellationToken.None).AsTask().GetAwaiter().GetResult();
+ }
+
+ public override async ValueTask ReadSchemaAsync(CancellationToken
cancellationToken)
{
if (HasReadSchema)
{
- return Schema;
+ return;
}
- var moveNextResult = await
_flightDataStream.MoveNext().ConfigureAwait(false);
+ var moveNextResult = await
_flightDataStream.MoveNext(cancellationToken).ConfigureAwait(false);
if (!moveNextResult)
{
@@ -87,12 +101,11 @@ namespace Apache.Arrow.Flight.Internal
switch (message.HeaderType)
{
case MessageHeader.Schema:
- Schema =
FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
+ _schema =
FlightMessageSerializer.DecodeSchema(message.ByteBuffer);
break;
default:
throw new Exception($"Expected schema as the first
message, but got: {message.HeaderType.ToString()}");
}
- return Schema;
}
public override async ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
@@ -101,7 +114,7 @@ namespace Apache.Arrow.Flight.Internal
if (!HasReadSchema)
{
- await ReadSchema().ConfigureAwait(false);
+ await ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
}
var moveNextResult = await
_flightDataStream.MoveNext().ConfigureAwait(false);
if (moveNextResult)
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
index 02f36b0793..4b7c5f914c 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileReaderImplementation.cs
@@ -52,7 +52,7 @@ namespace Apache.Arrow.Ipc
return _footer.RecordBatchCount;
}
- protected override async ValueTask ReadSchemaAsync(CancellationToken
cancellationToken = default)
+ public override async ValueTask ReadSchemaAsync(CancellationToken
cancellationToken = default)
{
if (HasReadSchema)
{
@@ -85,7 +85,7 @@ namespace Apache.Arrow.Ipc
}
}
- protected override void ReadSchema()
+ public override void ReadSchema()
{
if (HasReadSchema)
{
@@ -139,7 +139,7 @@ namespace Apache.Arrow.Ipc
// Deserialize the footer from the footer flatbuffer
_footer = new
ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)), ref
_dictionaryMemo);
- Schema = _footer.Schema;
+ _schema = _footer.Schema;
}
public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index,
CancellationToken cancellationToken)
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
index 6e2336a591..842c56823d 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowMemoryReaderImplementation.cs
@@ -33,6 +33,13 @@ namespace Apache.Arrow.Ipc
_buffer = buffer;
}
+ /// <summary>
+ /// Reads the schema from the in-memory buffer. The work is done
+ /// synchronously by ReadSchema(); only the cancellation check happens up front.
+ /// </summary>
+ public override ValueTask ReadSchemaAsync(CancellationToken
cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+ ReadSchema();
+ // default(ValueTask) is an already-completed ValueTask.
+ return default;
+ }
+
public override ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
@@ -93,7 +100,7 @@ namespace Apache.Arrow.Ipc
return batch;
}
- private void ReadSchema()
+ public override void ReadSchema()
{
if (HasReadSchema)
{
@@ -117,7 +124,7 @@ namespace Apache.Arrow.Ipc
}
ByteBuffer schemaBuffer =
CreateByteBuffer(_buffer.Slice(_bufferPosition));
- Schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemaBuffer), ref
_dictionaryMemo);
+ _schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemaBuffer), ref
_dictionaryMemo);
_bufferPosition += schemaMessageLength;
}
}
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
index eb7349a570..4e273dbde5 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs
@@ -30,13 +30,25 @@ namespace Apache.Arrow.Ipc
{
internal abstract class ArrowReaderImplementation : IDisposable
{
- public Schema Schema { get; protected set; }
- protected bool HasReadSchema => Schema != null;
+ /// <summary>
+ /// Gets the schema. On first access it is read lazily by calling the
+ /// synchronous ReadSchema(), so this getter can block.
+ /// </summary>
+ public Schema Schema
+ {
+ get
+ {
+ if (!HasReadSchema)
+ {
+ ReadSchema();
+ }
+ return _schema;
+ }
+ }
+
+ protected internal bool HasReadSchema => _schema != null;
private protected DictionaryMemo _dictionaryMemo;
private protected DictionaryMemo DictionaryMemo => _dictionaryMemo ??=
new DictionaryMemo();
private protected readonly MemoryAllocator _allocator;
private readonly ICompressionCodecFactory _compressionCodecFactory;
+ private protected Schema _schema;
private protected ArrowReaderImplementation() : this(null, null)
{ }
@@ -57,6 +69,9 @@ namespace Apache.Arrow.Ipc
{
}
+ // Read the schema and store it in _schema. The implementations in this
+ // change return immediately when HasReadSchema is already true.
+ public abstract ValueTask ReadSchemaAsync(CancellationToken
cancellationToken);
+ public abstract void ReadSchema();
+
public abstract ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken);
public abstract RecordBatch ReadNextRecordBatch();
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
index cdcfe7875d..e129da399d 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReader.cs
@@ -28,6 +28,9 @@ namespace Apache.Arrow.Ipc
{
private protected readonly ArrowReaderImplementation _implementation;
+ /// <summary>
+ /// May block if the schema hasn't yet been read. To avoid blocking,
use <see cref="GetSchema"/>.
+ /// </summary>
public Schema Schema => _implementation.Schema;
public ArrowStreamReader(Stream stream)
@@ -97,6 +100,15 @@ namespace Apache.Arrow.Ipc
}
}
+ /// <summary>
+ /// Asynchronously gets the schema, reading it first if it has not been read yet.
+ /// </summary>
+ /// <param name="cancellationToken">Token observed while the schema is read.</param>
+ /// <returns>The schema of the stream.</returns>
+ public async ValueTask<Schema> GetSchema(CancellationToken
cancellationToken = default)
+ {
+ if (!_implementation.HasReadSchema)
+ {
+ // Library code: do not resume on the caller's synchronization context.
+ await _implementation.ReadSchemaAsync(cancellationToken).ConfigureAwait(false);
+ }
+ return _implementation.Schema;
+ }
+
public ValueTask<RecordBatch>
ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
{
return _implementation.ReadNextRecordBatchAsync(cancellationToken);
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
index 5428c88c27..5583a58487 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
@@ -146,7 +146,7 @@ namespace Apache.Arrow.Ipc
return new ReadResult(messageLength, result);
}
- protected virtual async ValueTask ReadSchemaAsync(CancellationToken
cancellationToken = default)
+ public override async ValueTask ReadSchemaAsync(CancellationToken
cancellationToken = default)
{
if (HasReadSchema)
{
@@ -164,11 +164,11 @@ namespace Apache.Arrow.Ipc
EnsureFullRead(buff, bytesRead);
Google.FlatBuffers.ByteBuffer schemabb =
CreateByteBuffer(buff);
- Schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref
_dictionaryMemo);
+ _schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref
_dictionaryMemo);
}
}
- protected virtual void ReadSchema()
+ public override void ReadSchema()
{
if (HasReadSchema)
{
@@ -184,7 +184,7 @@ namespace Apache.Arrow.Ipc
EnsureFullRead(buff, bytesRead);
Google.FlatBuffers.ByteBuffer schemabb =
CreateByteBuffer(buff);
- Schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref
_dictionaryMemo);
+ _schema =
MessageSerializer.GetSchema(ReadMessage<Flatbuf.Schema>(schemabb), ref
_dictionaryMemo);
}
}
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
index 10315ff287..2e7488092c 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowReaderVerifier.cs
@@ -38,6 +38,9 @@ namespace Apache.Arrow.Tests
public static async Task VerifyReaderAsync(ArrowStreamReader reader,
RecordBatch originalBatch)
{
+ Schema schema = await reader.GetSchema();
+ Assert.NotNull(schema);
+
RecordBatch readBatch = await reader.ReadNextRecordBatchAsync();
CompareBatches(originalBatch, readBatch);
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
index ed030cc6ac..b9e4664fdc 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamReaderTests.cs
@@ -94,6 +94,8 @@ namespace Apache.Arrow.Tests
{
await TestReaderFromMemory((reader, originalBatch) =>
{
+ Assert.NotNull(reader.Schema);
+
ArrowReaderVerifier.VerifyReader(reader, originalBatch);
return Task.CompletedTask;
}, writeEnd);