This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new ead8d6fdd1 GH-44363: [C#] Handle Flight data with zero batches (#45315)
ead8d6fdd1 is described below

commit ead8d6fdd12a936cc35a51ccb5af80674eda0faa
Author: Adam Reeve <[email protected]>
AuthorDate: Wed Jan 22 03:00:02 2025 +1300

    GH-44363: [C#] Handle Flight data with zero batches (#45315)
    
    ### Rationale for this change
    
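    See #44363. Previously the schema was only sent along with the first
    record batch, so a stream with zero batches never transmitted a schema.
    Sending the schema up front improves compatibility with other Flight
    implementations and means user code works with empty data without
    needing to treat it as a special case.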
    
    ### What changes are included in this PR?
    
    * Adds new async overloads of `FlightClient.StartPut` that immediately send 
the schema, before any data batches are sent.
    * Updates the test server to send the schema on `DoGet` even when there are 
no data batches.
    * Enables the `primitive_no_batches` test case for C# Flight.
    
    ### Are these changes tested?
    
    Yes, using a new unit test and with the integration tests.
    
    ### Are there any user-facing changes?
    
    Yes. New async overloads of the `FlightClient.StartPut` method have been
    added. They accept a `Schema` parameter and ensure the schema is sent
    even when no data batches are sent.
    
    * GitHub Issue: #44363
    
    Authored-by: Adam Reeve <[email protected]>
    Signed-off-by: Curt Hagenlocher <[email protected]>
---
 .../src/Apache.Arrow.Flight/Client/FlightClient.cs | 55 ++++++++++++++++++++++
 .../FlightRecordBatchStreamWriter.cs               | 21 +++++++--
 .../Internal/FlightDataStream.cs                   |  2 +-
 .../Scenarios/JsonTestScenario.cs                  |  7 +--
 .../TestFlightServer.cs                            |  3 +-
 .../test/Apache.Arrow.Flight.Tests/FlightTests.cs  | 36 +++++++++++---
 dev/archery/archery/integration/datagen.py         |  5 +-
 7 files changed, 109 insertions(+), 20 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs 
b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
index b89ce9da79..10660f40b4 100644
--- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
+++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
@@ -98,11 +98,39 @@ namespace Apache.Arrow.Flight.Client
                 flightInfoResult.Dispose);
         }
 
+        /// <summary>
+        /// Start a Flight Put request.
+        /// </summary>
+        /// <param name="flightDescriptor">Descriptor for the data to be 
put</param>
+        /// <param name="headers">gRPC headers to send with the request</param>
+        /// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> 
object used to write data batches and receive responses</returns>
         public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor 
flightDescriptor, Metadata headers = null)
         {
             return StartPut(flightDescriptor, headers, null, 
CancellationToken.None);
         }
 
+        /// <summary>
+        /// Start a Flight Put request.
+        /// </summary>
+        /// <param name="flightDescriptor">Descriptor for the data to be 
put</param>
+        /// <param name="schema">The schema of the data</param>
+        /// <param name="headers">gRPC headers to send with the request</param>
+        /// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> 
object used to write data batches and receive responses</returns>
+        /// <remarks>Using this method rather than a StartPut overload that 
doesn't accept a schema
+        /// means that the schema is sent even if no data batches are 
sent</remarks>
+        public Task<FlightRecordBatchDuplexStreamingCall> 
StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers = 
null)
+        {
+            return StartPut(flightDescriptor, schema, headers, null, 
CancellationToken.None);
+        }
+
+        /// <summary>
+        /// Start a Flight Put request.
+        /// </summary>
+        /// <param name="flightDescriptor">Descriptor for the data to be 
put</param>
+        /// <param name="headers">gRPC headers to send with the request</param>
+        /// <param name="deadline">Optional deadline. The request will be 
cancelled if this deadline is reached.</param>
+        /// <param name="cancellationToken">Optional token for cancelling the 
request</param>
+        /// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> 
object used to write data batches and receive responses</returns>
         public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor 
flightDescriptor, Metadata headers, System.DateTime? deadline, 
CancellationToken cancellationToken = default)
         {
             var channels = _client.DoPut(headers, deadline, cancellationToken);
@@ -117,6 +145,33 @@ namespace Apache.Arrow.Flight.Client
                 channels.Dispose);
         }
 
+        /// <summary>
+        /// Start a Flight Put request.
+        /// </summary>
+        /// <param name="flightDescriptor">Descriptor for the data to be 
put</param>
+        /// <param name="schema">The schema of the data</param>
+        /// <param name="headers">gRPC headers to send with the request</param>
+        /// <param name="deadline">Optional deadline. The request will be 
cancelled if this deadline is reached.</param>
+        /// <param name="cancellationToken">Optional token for cancelling the 
request</param>
+        /// <returns>A <see cref="FlightRecordBatchDuplexStreamingCall" /> 
object used to write data batches and receive responses</returns>
+        /// <remarks>Using this method rather than a StartPut overload that 
doesn't accept a schema
+        /// means that the schema is sent even if no data batches are 
sent</remarks>
+        public async Task<FlightRecordBatchDuplexStreamingCall> 
StartPut(FlightDescriptor flightDescriptor, Schema schema, Metadata headers, 
System.DateTime? deadline, CancellationToken cancellationToken = default)
+        {
+            var channels = _client.DoPut(headers, deadline, cancellationToken);
+            var requestStream = new 
FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
+            var readStream = new StreamReader<Protocol.PutResult, 
FlightPutResult>(channels.ResponseStream, putResult => new 
FlightPutResult(putResult));
+            var streamingCall = new FlightRecordBatchDuplexStreamingCall(
+                requestStream,
+                readStream,
+                channels.ResponseHeadersAsync,
+                channels.GetStatus,
+                channels.GetTrailers,
+                channels.Dispose);
+            await 
streamingCall.RequestStream.SetupStream(schema).ConfigureAwait(false);
+            return streamingCall;
+        }
+
         public AsyncDuplexStreamingCall<FlightHandshakeRequest, 
FlightHandshakeResponse> Handshake(Metadata headers = null)
         {
             return Handshake(headers, null, CancellationToken.None);
diff --git a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs 
b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
index 7a8a6fd677..314d46da00 100644
--- a/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
+++ b/csharp/src/Apache.Arrow.Flight/FlightRecordBatchStreamWriter.cs
@@ -38,9 +38,22 @@ namespace Apache.Arrow.Flight
             _flightDescriptor = flightDescriptor;
         }
 
-        private void SetupStream(Schema schema)
+        /// <summary>
+        /// Configure the data stream to write to.
+        /// </summary>
+        /// <remarks>
+        /// The stream will be set up automatically when writing a RecordBatch 
if required,
+        /// but calling this method before writing any data allows handling 
empty streams.
+        /// </remarks>
+        /// <param name="schema">The schema of data to be written to this 
stream</param>
+        public async Task SetupStream(Schema schema)
         {
+            if (_flightDataStream != null)
+            {
+                throw new InvalidOperationException("Flight data stream is 
already set");
+            }
             _flightDataStream = new FlightDataStream(_clientStreamWriter, 
_flightDescriptor, schema);
+            await _flightDataStream.SendSchema().ConfigureAwait(false);
         }
 
         public WriteOptions WriteOptions { get => throw new 
NotImplementedException(); set => throw new NotImplementedException(); }
@@ -50,14 +63,14 @@ namespace Apache.Arrow.Flight
             return WriteAsync(message, default);
         }
 
-        public Task WriteAsync(RecordBatch message, ByteString 
applicationMetadata)
+        public async Task WriteAsync(RecordBatch message, ByteString 
applicationMetadata)
         {
             if (_flightDataStream == null)
             {
-                SetupStream(message.Schema);
+                await SetupStream(message.Schema).ConfigureAwait(false);
             }
 
-            return _flightDataStream.Write(message, applicationMetadata);
+            await _flightDataStream.Write(message, applicationMetadata);
         }
 
         protected virtual void Dispose(bool disposing)
diff --git a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs 
b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
index 72c1551be2..7cbbe66f40 100644
--- a/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
+++ b/csharp/src/Apache.Arrow.Flight/Internal/FlightDataStream.cs
@@ -44,7 +44,7 @@ namespace Apache.Arrow.Flight.Internal
             _flightDescriptor = flightDescriptor;
         }
 
-        private async Task SendSchema()
+        public async Task SendSchema()
         {
             _currentFlightData = new Protocol.FlightData();
 
diff --git 
a/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs 
b/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs
index 4f7fed7435..7847510440 100644
--- 
a/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs
+++ 
b/csharp/test/Apache.Arrow.Flight.IntegrationTest/Scenarios/JsonTestScenario.cs
@@ -76,7 +76,7 @@ internal class JsonTestScenario : IScenario
         var batches = jsonFile.Batches.Select(batch => batch.ToArrow(schema, 
dictionaries)).ToArray();
 
         // 1. Put the data to the server.
-        await UploadBatches(client, descriptor, batches).ConfigureAwait(false);
+        await UploadBatches(client, descriptor, schema, 
batches).ConfigureAwait(false);
 
         // 2. Get the ticket for the data.
         var info = await client.GetInfo(descriptor).ConfigureAwait(false);
@@ -112,9 +112,10 @@ internal class JsonTestScenario : IScenario
         }
     }
 
-    private static async Task UploadBatches(FlightClient client, 
FlightDescriptor descriptor, RecordBatch[] batches)
+    private static async Task UploadBatches(
+        FlightClient client, FlightDescriptor descriptor, Schema schema, 
RecordBatch[] batches)
     {
-        using var putCall = client.StartPut(descriptor);
+        using var putCall = await client.StartPut(descriptor, schema);
         using var writer = putCall.RequestStream;
 
         try
diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs 
b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
index 46c5460912..5689b45bfd 100644
--- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
+++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs
@@ -51,9 +51,10 @@ namespace Apache.Arrow.Flight.TestWeb
 
             if(_flightStore.Flights.TryGetValue(flightDescriptor, out var 
flightHolder))
             {
+                await 
responseStream.SetupStream(flightHolder.GetFlightInfo().Schema);
+
                 var batches = flightHolder.GetRecordBatches();
 
-                
                 foreach(var batch in batches)
                 {
                     await responseStream.WriteAsync(batch.RecordBatch, 
batch.Metadata);
diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs 
b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
index 350762c992..241b3c006a 100644
--- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
+++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
@@ -57,6 +57,13 @@ namespace Apache.Arrow.Flight.Tests
             return batchBuilder.Build();
         }
 
+        private Schema GetStoreSchema(FlightDescriptor flightDescriptor)
+        {
+            Assert.Contains(flightDescriptor, 
(IReadOnlyDictionary<FlightDescriptor, FlightHolder>)_flightStore.Flights);
+
+            var flightHolder = _flightStore.Flights[flightDescriptor];
+            return flightHolder.GetFlightInfo().Schema;
+        }
 
         private IEnumerable<RecordBatchWithMetadata> 
GetStoreBatch(FlightDescriptor flightDescriptor)
         {
@@ -88,7 +95,7 @@ namespace Apache.Arrow.Flight.Tests
             var flightDescriptor = 
FlightDescriptor.CreatePathDescriptor("test");
             var expectedBatch = CreateTestBatch(0, 100);
 
-            var putStream = _flightClient.StartPut(flightDescriptor);
+            var putStream = await _flightClient.StartPut(flightDescriptor, 
expectedBatch.Schema);
             await putStream.RequestStream.WriteAsync(expectedBatch);
             await putStream.RequestStream.CompleteAsync();
             var putResults = await putStream.ResponseStream.ToListAsync();
@@ -108,7 +115,7 @@ namespace Apache.Arrow.Flight.Tests
             var expectedBatch1 = CreateTestBatch(0, 100);
             var expectedBatch2 = CreateTestBatch(0, 100);
 
-            var putStream = _flightClient.StartPut(flightDescriptor);
+            var putStream = await _flightClient.StartPut(flightDescriptor, 
expectedBatch1.Schema);
             await putStream.RequestStream.WriteAsync(expectedBatch1);
             await putStream.RequestStream.WriteAsync(expectedBatch2);
             await putStream.RequestStream.CompleteAsync();
@@ -123,6 +130,23 @@ namespace Apache.Arrow.Flight.Tests
             ArrowReaderVerifier.CompareBatches(expectedBatch2, 
actualBatches[1].RecordBatch);
         }
 
+        [Fact]
+        public async Task TestPutZeroRecordBatches()
+        {
+            var flightDescriptor = 
FlightDescriptor.CreatePathDescriptor("test");
+            var schema = CreateTestBatch(0, 1).Schema;
+
+            var putStream = await _flightClient.StartPut(flightDescriptor, 
schema);
+            await putStream.RequestStream.CompleteAsync();
+            var putResults = await putStream.ResponseStream.ToListAsync();
+
+            Assert.Empty(putResults);
+
+            var actualSchema = GetStoreSchema(flightDescriptor);
+
+            SchemaComparer.Compare(schema, actualSchema);
+        }
+
         [Fact]
         public async Task TestGetRecordBatchWithDelayedSchema()
         {
@@ -230,7 +254,7 @@ namespace Apache.Arrow.Flight.Tests
             var expectedBatch = CreateTestBatch(0, 100);
             var expectedMetadata = ByteString.CopyFromUtf8("test metadata");
 
-            var putStream = _flightClient.StartPut(flightDescriptor);
+            var putStream = await _flightClient.StartPut(flightDescriptor, 
expectedBatch.Schema);
             await putStream.RequestStream.WriteAsync(expectedBatch, 
expectedMetadata);
             await putStream.RequestStream.CompleteAsync();
             var putResults = await putStream.ResponseStream.ToListAsync();
@@ -471,8 +495,7 @@ namespace Apache.Arrow.Flight.Tests
             exception = await Assert.ThrowsAsync<RpcException>(async () => 
await duplexStreamingCall.RequestStream.WriteAsync(batch));
             Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
 
-            var putStream = _flightClient.StartPut(flightDescriptor, null, 
deadline);
-            exception = await Assert.ThrowsAsync<RpcException>(async () => 
await putStream.RequestStream.WriteAsync(batch));
+            exception = await Assert.ThrowsAsync<RpcException>(async () => 
await _flightClient.StartPut(flightDescriptor, batch.Schema, null, deadline));
             Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
 
             exception = await Assert.ThrowsAsync<RpcException>(async () => 
await _flightClient.GetSchema(flightDescriptor, null, deadline));
@@ -514,8 +537,7 @@ namespace Apache.Arrow.Flight.Tests
             exception = await Assert.ThrowsAsync<RpcException>(async () => 
await duplexStreamingCall.RequestStream.WriteAsync(batch));
             Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
 
-            var putStream = _flightClient.StartPut(flightDescriptor, null, 
null, cts.Token);
-            exception = await Assert.ThrowsAsync<RpcException>(async () => 
await putStream.RequestStream.WriteAsync(batch));
+            exception = await Assert.ThrowsAsync<RpcException>(async () => 
await _flightClient.StartPut(flightDescriptor, batch.Schema, null, null, 
cts.Token));
             Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
 
             exception = await Assert.ThrowsAsync<RpcException>(async () => 
await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
diff --git a/dev/archery/archery/integration/datagen.py 
b/dev/archery/archery/integration/datagen.py
index b4fbbb2d41..027e675792 100644
--- a/dev/archery/archery/integration/datagen.py
+++ b/dev/archery/archery/integration/datagen.py
@@ -1890,10 +1890,7 @@ def get_generated_json_files(tempdir=None):
         return
 
     file_objs = [
-        generate_primitive_case([], name='primitive_no_batches')
-        # TODO(https://github.com/apache/arrow/issues/44363)
-        .skip_format(SKIP_FLIGHT, 'C#'),
-
+        generate_primitive_case([], name='primitive_no_batches'),
         generate_primitive_case([17, 20], name='primitive'),
         generate_primitive_case([0, 0, 0], name='primitive_zerolength'),
 

Reply via email to