This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new b0786d48a5 GH-43907: [C#][FlightRPC] Add Grpc Call Options support on 
Flight Client (#43910)
b0786d48a5 is described below

commit b0786d48a58a5f95fe22db12932a3a9dffb4101f
Author: qmmk <[email protected]>
AuthorDate: Tue Sep 3 15:30:06 2024 +0200

    GH-43907: [C#][FlightRPC] Add Grpc Call Options support on Flight Client 
(#43910)
    
    ### Rationale for this change
    
    This change adds support for the standard gRPC call options (deadline and cancellation token) to the C# FlightClient implementation.
    
    ### What changes are included in this PR?
    
    - FlightClient.cs: new overloads of every client method that accept gRPC call options (deadline and cancellation token)
    - FlightTests.cs: new tests verifying that calls fail with the expected status (DeadlineExceeded / Cancelled)
    
    ### Are these changes tested?
    
    Yes, tests are added in FlightTests.cs.
    I've tested locally with the C++ implementation.
    
    ### Are there any user-facing changes?
    
    No, the change is transparent to the user; following the existing documentation should be sufficient.
    
    ### References
    
    * GitHub Issue: #43907
    
    Authored-by: Marco Malagoli <[email protected]>
    Signed-off-by: Curt Hagenlocher <[email protected]>
---
 .../src/Apache.Arrow.Flight/Client/FlightClient.cs | 69 ++++++++++++---
 .../test/Apache.Arrow.Flight.Tests/FlightTests.cs  | 97 ++++++++++++++++++++--
 2 files changed, 150 insertions(+), 16 deletions(-)

diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs 
b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
index efb22b1948..b89ce9da79 100644
--- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
+++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
@@ -13,6 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Flight.Internal;
 using Apache.Arrow.Flight.Protocol;
@@ -34,12 +35,17 @@ namespace Apache.Arrow.Flight.Client
 
         public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria 
criteria = null, Metadata headers = null)
         {
-            if(criteria == null)
+            return ListFlights(criteria, headers, null, 
CancellationToken.None);
+        }
+        /// <summary>Issues the ListFlights call for <paramref name="criteria"/> (<c>FlightCriteria.Empty</c> when null), forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria 
criteria, Metadata headers, System.DateTime? deadline, CancellationToken 
cancellationToken = default)
+        {
+            if (criteria == null)
             {
                 criteria = FlightCriteria.Empty;
             }
-            
-            var response = _client.ListFlights(criteria.ToProtocol(), headers);
+
+            var response = _client.ListFlights(criteria.ToProtocol(), headers, 
deadline, cancellationToken);
             var convertStream = new StreamReader<Protocol.FlightInfo, 
FlightInfo>(response.ResponseStream, inFlight => new FlightInfo(inFlight));
 
             return new AsyncServerStreamingCall<FlightInfo>(convertStream, 
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, 
response.Dispose);
@@ -47,7 +53,12 @@ namespace Apache.Arrow.Flight.Client
 
         public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata 
headers = null)
         {
-            var response = _client.ListActions(EmptyInstance, headers);
+            return ListActions(headers, null, CancellationToken.None);
+        }
+        /// <summary>Issues the ListActions call, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata 
headers, System.DateTime? deadline, CancellationToken cancellationToken = 
default)
+        {
+            var response = _client.ListActions(EmptyInstance, headers, 
deadline, cancellationToken);
             var convertStream = new StreamReader<Protocol.ActionType, 
FlightActionType>(response.ResponseStream, actionType => new 
FlightActionType(actionType));
 
             return new 
AsyncServerStreamingCall<FlightActionType>(convertStream, 
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, 
response.Dispose);
@@ -55,14 +66,24 @@ namespace Apache.Arrow.Flight.Client
 
         public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, 
Metadata headers = null)
         {
-            var stream = _client.DoGet(ticket.ToProtocol(),  headers);
+            return GetStream(ticket, headers, null, CancellationToken.None);
+        }
+        /// <summary>Issues the DoGet call for <paramref name="ticket"/>, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, 
Metadata headers, System.DateTime? deadline, CancellationToken 
cancellationToken = default)
+        {
+            var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline, 
cancellationToken);
             var responseStream = new 
FlightClientRecordBatchStreamReader(stream.ResponseStream);
             return new FlightRecordBatchStreamingCall(responseStream, 
stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, 
stream.Dispose);
         }
 
         public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor 
flightDescriptor, Metadata headers = null)
         {
-            var flightInfoResult = 
_client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers);
+            return GetInfo(flightDescriptor, headers, null, 
CancellationToken.None);
+        }
+        /// <summary>Issues the GetFlightInfo call for <paramref name="flightDescriptor"/>, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor 
flightDescriptor, Metadata headers, System.DateTime? deadline, 
CancellationToken cancellationToken = default)
+        {
+            var flightInfoResult = 
_client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers, deadline, 
cancellationToken);
 
             var flightInfo = flightInfoResult
                 .ResponseAsync
@@ -79,7 +100,12 @@ namespace Apache.Arrow.Flight.Client
 
         public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor 
flightDescriptor, Metadata headers = null)
         {
-            var channels = _client.DoPut(headers);
+            return StartPut(flightDescriptor, headers, null, 
CancellationToken.None);
+        }
+        /// <summary>Starts a DoPut call whose request writer is created for <paramref name="flightDescriptor"/>, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor 
flightDescriptor, Metadata headers, System.DateTime? deadline, 
CancellationToken cancellationToken = default)
+        {
+            var channels = _client.DoPut(headers, deadline, cancellationToken);
             var requestStream = new 
FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
             var readStream = new StreamReader<Protocol.PutResult, 
FlightPutResult>(channels.ResponseStream, putResult => new 
FlightPutResult(putResult));
             return new FlightRecordBatchDuplexStreamingCall(
@@ -93,7 +119,13 @@ namespace Apache.Arrow.Flight.Client
 
         public AsyncDuplexStreamingCall<FlightHandshakeRequest, 
FlightHandshakeResponse> Handshake(Metadata headers = null)
         {
-            var channel = _client.Handshake(headers);
+            return Handshake(headers, null, CancellationToken.None);
+
+        }
+        /// <summary>Starts a Handshake call, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public AsyncDuplexStreamingCall<FlightHandshakeRequest, 
FlightHandshakeResponse> Handshake(Metadata headers, System.DateTime? deadline, 
CancellationToken cancellationToken = default)
+        {
+            var channel = _client.Handshake(headers, deadline, 
cancellationToken);
             var readStream = new StreamReader<HandshakeResponse, 
FlightHandshakeResponse>(channel.ResponseStream, response => new 
FlightHandshakeResponse(response));
             var writeStream = new 
FlightHandshakeStreamWriterAdapter(channel.RequestStream);
             var call = new AsyncDuplexStreamingCall<FlightHandshakeRequest, 
FlightHandshakeResponse>(
@@ -109,7 +141,12 @@ namespace Apache.Arrow.Flight.Client
 
         public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor 
flightDescriptor, Metadata headers = null)
         {
-            var channel = _client.DoExchange(headers);
+            return DoExchange(flightDescriptor, headers, null, 
CancellationToken.None);
+        }
+        /// <summary>Starts a DoExchange call whose request writer is created for <paramref name="flightDescriptor"/>, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor 
flightDescriptor, Metadata headers, System.DateTime? deadline, 
CancellationToken cancellationToken = default)
+        {
+            var channel = _client.DoExchange(headers, deadline, 
cancellationToken);
             var requestStream = new 
FlightClientRecordBatchStreamWriter(channel.RequestStream, flightDescriptor);
             var responseStream = new 
FlightClientRecordBatchStreamReader(channel.ResponseStream);
             var call = new FlightRecordBatchExchangeCall(
@@ -125,14 +162,24 @@ namespace Apache.Arrow.Flight.Client
 
         public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction 
action, Metadata headers = null)
         {
-            var stream = _client.DoAction(action.ToProtocol(), headers);
+            return DoAction(action, headers, null, CancellationToken.None);
+        }
+        /// <summary>Issues the DoAction call for <paramref name="action"/>, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction 
action, Metadata headers, System.DateTime? deadline, CancellationToken 
cancellationToken = default)
+        {
+            var stream = _client.DoAction(action.ToProtocol(), headers, 
deadline, cancellationToken);
             var streamReader = new StreamReader<Protocol.Result, 
FlightResult>(stream.ResponseStream, result => new FlightResult(result));
             return new AsyncServerStreamingCall<FlightResult>(streamReader, 
stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, 
stream.Dispose);
         }
 
         public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor 
flightDescriptor, Metadata headers = null)
         {
-            var schemaResult = 
_client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers);
+            return GetSchema(flightDescriptor, headers, null, 
CancellationToken.None);
+        }
+        /// <summary>Issues the GetSchema call for <paramref name="flightDescriptor"/>, forwarding <paramref name="headers"/>, <paramref name="deadline"/> and <paramref name="cancellationToken"/> to the gRPC client.</summary>
+        public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor 
flightDescriptor, Metadata headers, System.DateTime? deadline, 
CancellationToken cancellationToken = default)
+        {
+            var schemaResult = 
_client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers, deadline, 
cancellationToken);
 
             var schema = schemaResult
                 .ResponseAsync
diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs 
b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
index aac4e42092..8bf6e1120c 100644
--- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
+++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
@@ -16,12 +16,15 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;
 using Apache.Arrow.Flight.Client;
 using Apache.Arrow.Flight.TestWeb;
 using Apache.Arrow.Tests;
 using Google.Protobuf;
+using Grpc.Core;
 using Grpc.Core.Utils;
+using Python.Runtime;
 using Xunit;
 
 namespace Apache.Arrow.Flight.Tests
@@ -70,7 +73,7 @@ namespace Apache.Arrow.Flight.Tests
 
             var flightHolder = new FlightHolder(flightDescriptor, 
initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress());
 
-            foreach(var batch in batches)
+            foreach (var batch in batches)
             {
                 flightHolder.AddBatch(batch);
             }
@@ -187,8 +190,8 @@ namespace Apache.Arrow.Flight.Tests
 
             var getStream = _flightClient.GetStream(endpoint.Ticket);
 
-            List<ByteString> actualMetadata = new List<ByteString>(); 
-            while(await getStream.ResponseStream.MoveNext(default))
+            List<ByteString> actualMetadata = new List<ByteString>();
+            while (await getStream.ResponseStream.MoveNext(default))
             {
                 
actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
             }
@@ -277,7 +280,7 @@ namespace Apache.Arrow.Flight.Tests
 
             var actualFlights = await 
listFlightStream.ResponseStream.ToListAsync();
 
-            for(int i = 0; i < expectedFlightInfo.Count; i++)
+            for (int i = 0; i < expectedFlightInfo.Count; i++)
             {
                 FlightInfoComparer.Compare(expectedFlightInfo[i], 
actualFlights[i]);
             }
@@ -386,7 +389,7 @@ namespace Apache.Arrow.Flight.Tests
 
 
             List<RecordBatch> resultList = new List<RecordBatch>();
-            await foreach(var recordBatch in getStream.ResponseStream)
+            await foreach (var recordBatch in getStream.ResponseStream)
             {
                 resultList.Add(recordBatch);
             }
@@ -415,5 +418,89 @@ namespace Apache.Arrow.Flight.Tests
             Assert.Equal(expectedBatch.Length, result.TotalRecords);
             Assert.Equal(expectedTotalBytes, result.TotalBytes);
         }
+
+        [Fact]
+        public async Task EnsureCallRaisesDeadlineExceeded()
+        {
+            var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_deadline");
+            var batch = CreateTestBatch(0, 100);
+            // Captured before any call is issued, so every call below starts at or after its deadline.
+            var deadline = DateTime.UtcNow;
+            RpcException exception;
+
+            var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline);
+            Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);
+
+            var asyncServerStreamingCallActions = _flightClient.ListActions(null, deadline);
+            Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallActions.GetStatus().StatusCode);
+
+            GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, deadline));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            var flightInfo = await _flightClient.GetInfo(flightDescriptor);
+            var endpoint = flightInfo.Endpoints.First();
+            var getStream = _flightClient.GetStream(endpoint.Ticket, null, deadline);
+            Assert.Equal(StatusCode.DeadlineExceeded, getStream.GetStatus().StatusCode);
+
+            var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, deadline);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            var handshakeStreamingCall = _flightClient.Handshake(null, deadline);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+        }
+
+        [Fact]
+        public async Task EnsureCallRaisesRequestCancelled()
+        {
+            // Cancel before issuing any call so each call deterministically observes an
+            // already-cancelled token (CancelAfter followed by Task.Delay is timing-dependent).
+            using var cts = new CancellationTokenSource();
+            cts.Cancel();
+
+            var batch = CreateTestBatch(0, 100);
+            var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled");
+            RpcException exception;
+
+            var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token);
+            Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);
+
+            var asyncServerStreamingCallActions = _flightClient.ListActions(null, null, cts.Token);
+            Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallActions.GetStatus().StatusCode);
+
+            GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, null, cts.Token));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            var flightInfo = await _flightClient.GetInfo(flightDescriptor);
+            var endpoint = flightInfo.Endpoints.First();
+            var getStream = _flightClient.GetStream(endpoint.Ticket, null, null, cts.Token);
+            Assert.Equal(StatusCode.Cancelled, getStream.GetStatus().StatusCode);
+
+            var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            var handshakeStreamingCall = _flightClient.Handshake(null, null, cts.Token);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+        }
     }
 }

Reply via email to