This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new b0786d48a5 GH-43907: [C#][FlightRPC] Add Grpc Call Options support on
Flight Client (#43910)
b0786d48a5 is described below
commit b0786d48a58a5f95fe22db12932a3a9dffb4101f
Author: qmmk <[email protected]>
AuthorDate: Tue Sep 3 15:30:06 2024 +0200
GH-43907: [C#][FlightRPC] Add Grpc Call Options support on Flight Client
(#43910)
### Rationale for this change
This change adds gRPC call options (a deadline and a cancellation token) to the
C# FlightClient implementation.
### What changes are included in this PR?
- FlightClient.cs: new overloads of all client methods that accept gRPC call
options (deadline and cancellation token); the existing overloads delegate to them
- FlightTests.cs: new tests verifying that the expected exceptions are raised
### Are these changes tested?
Yes, tests are added in FlightTests.cs.
I've tested locally with the C++ implementation.
### Are there any user-facing changes?
No; the change is transparent to existing users, so the existing documentation
should be sufficient.
### References
* GitHub Issue: #43907
Authored-by: Marco Malagoli <[email protected]>
Signed-off-by: Curt Hagenlocher <[email protected]>
---
.../src/Apache.Arrow.Flight/Client/FlightClient.cs | 69 ++++++++++++---
.../test/Apache.Arrow.Flight.Tests/FlightTests.cs | 97 ++++++++++++++++++++--
2 files changed, 150 insertions(+), 16 deletions(-)
diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
index efb22b1948..b89ce9da79 100644
--- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
+++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Flight.Protocol;
@@ -34,12 +35,17 @@ namespace Apache.Arrow.Flight.Client
public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria
criteria = null, Metadata headers = null)
{
- if(criteria == null)
+ return ListFlights(criteria, headers, null,
CancellationToken.None);
+ }
+
+ public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria
criteria, Metadata headers, System.DateTime? deadline, CancellationToken
cancellationToken = default)
+ {
+ if (criteria == null)
{
criteria = FlightCriteria.Empty;
}
-
- var response = _client.ListFlights(criteria.ToProtocol(), headers);
+
+ var response = _client.ListFlights(criteria.ToProtocol(), headers,
deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.FlightInfo,
FlightInfo>(response.ResponseStream, inFlight => new FlightInfo(inFlight));
return new AsyncServerStreamingCall<FlightInfo>(convertStream,
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers,
response.Dispose);
@@ -47,7 +53,12 @@ namespace Apache.Arrow.Flight.Client
public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata
headers = null)
{
- var response = _client.ListActions(EmptyInstance, headers);
+ return ListActions(headers, null, CancellationToken.None);
+ }
+
+ public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata
headers, System.DateTime? deadline, CancellationToken cancellationToken =
default)
+ {
+ var response = _client.ListActions(EmptyInstance, headers,
deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.ActionType,
FlightActionType>(response.ResponseStream, actionType => new
FlightActionType(actionType));
return new
AsyncServerStreamingCall<FlightActionType>(convertStream,
response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers,
response.Dispose);
@@ -55,14 +66,24 @@ namespace Apache.Arrow.Flight.Client
public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket,
Metadata headers = null)
{
- var stream = _client.DoGet(ticket.ToProtocol(), headers);
+ return GetStream(ticket, headers, null, CancellationToken.None);
+ }
+
+ public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket,
Metadata headers, System.DateTime? deadline, CancellationToken
cancellationToken = default)
+ {
+ var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline,
cancellationToken);
var responseStream = new
FlightClientRecordBatchStreamReader(stream.ResponseStream);
return new FlightRecordBatchStreamingCall(responseStream,
stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers,
stream.Dispose);
}
public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor
flightDescriptor, Metadata headers = null)
{
- var flightInfoResult =
_client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers);
+ return GetInfo(flightDescriptor, headers, null,
CancellationToken.None);
+ }
+
+ public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor
flightDescriptor, Metadata headers, System.DateTime? deadline,
CancellationToken cancellationToken = default)
+ {
+ var flightInfoResult =
_client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers, deadline,
cancellationToken);
var flightInfo = flightInfoResult
.ResponseAsync
@@ -79,7 +100,12 @@ namespace Apache.Arrow.Flight.Client
public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor
flightDescriptor, Metadata headers = null)
{
- var channels = _client.DoPut(headers);
+ return StartPut(flightDescriptor, headers, null,
CancellationToken.None);
+ }
+
+ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor
flightDescriptor, Metadata headers, System.DateTime? deadline,
CancellationToken cancellationToken = default)
+ {
+ var channels = _client.DoPut(headers, deadline, cancellationToken);
var requestStream = new
FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
var readStream = new StreamReader<Protocol.PutResult,
FlightPutResult>(channels.ResponseStream, putResult => new
FlightPutResult(putResult));
return new FlightRecordBatchDuplexStreamingCall(
@@ -93,7 +119,13 @@ namespace Apache.Arrow.Flight.Client
public AsyncDuplexStreamingCall<FlightHandshakeRequest,
FlightHandshakeResponse> Handshake(Metadata headers = null)
{
- var channel = _client.Handshake(headers);
+ return Handshake(headers, null, CancellationToken.None);
+
+ }
+
+ public AsyncDuplexStreamingCall<FlightHandshakeRequest,
FlightHandshakeResponse> Handshake(Metadata headers, System.DateTime? deadline,
CancellationToken cancellationToken = default)
+ {
+ var channel = _client.Handshake(headers, deadline,
cancellationToken);
var readStream = new StreamReader<HandshakeResponse,
FlightHandshakeResponse>(channel.ResponseStream, response => new
FlightHandshakeResponse(response));
var writeStream = new
FlightHandshakeStreamWriterAdapter(channel.RequestStream);
var call = new AsyncDuplexStreamingCall<FlightHandshakeRequest,
FlightHandshakeResponse>(
@@ -109,7 +141,12 @@ namespace Apache.Arrow.Flight.Client
public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor
flightDescriptor, Metadata headers = null)
{
- var channel = _client.DoExchange(headers);
+ return DoExchange(flightDescriptor, headers, null,
CancellationToken.None);
+ }
+
+ public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor
flightDescriptor, Metadata headers, System.DateTime? deadline,
CancellationToken cancellationToken = default)
+ {
+ var channel = _client.DoExchange(headers, deadline,
cancellationToken);
var requestStream = new
FlightClientRecordBatchStreamWriter(channel.RequestStream, flightDescriptor);
var responseStream = new
FlightClientRecordBatchStreamReader(channel.ResponseStream);
var call = new FlightRecordBatchExchangeCall(
@@ -125,14 +162,24 @@ namespace Apache.Arrow.Flight.Client
public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction
action, Metadata headers = null)
{
- var stream = _client.DoAction(action.ToProtocol(), headers);
+ return DoAction(action, headers, null, CancellationToken.None);
+ }
+
+ public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction
action, Metadata headers, System.DateTime? deadline, CancellationToken
cancellationToken = default)
+ {
+ var stream = _client.DoAction(action.ToProtocol(), headers,
deadline, cancellationToken);
var streamReader = new StreamReader<Protocol.Result,
FlightResult>(stream.ResponseStream, result => new FlightResult(result));
return new AsyncServerStreamingCall<FlightResult>(streamReader,
stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers,
stream.Dispose);
}
public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor
flightDescriptor, Metadata headers = null)
{
- var schemaResult =
_client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers);
+ return GetSchema(flightDescriptor, headers, null,
CancellationToken.None);
+ }
+
+ public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor
flightDescriptor, Metadata headers, System.DateTime? deadline,
CancellationToken cancellationToken = default)
+ {
+ var schemaResult =
_client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers, deadline,
cancellationToken);
var schema = schemaResult
.ResponseAsync
diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
index aac4e42092..8bf6e1120c 100644
--- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
+++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
@@ -16,12 +16,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Client;
using Apache.Arrow.Flight.TestWeb;
using Apache.Arrow.Tests;
using Google.Protobuf;
+using Grpc.Core;
using Grpc.Core.Utils;
+using Python.Runtime;
using Xunit;
namespace Apache.Arrow.Flight.Tests
@@ -70,7 +73,7 @@ namespace Apache.Arrow.Flight.Tests
var flightHolder = new FlightHolder(flightDescriptor,
initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress());
- foreach(var batch in batches)
+ foreach (var batch in batches)
{
flightHolder.AddBatch(batch);
}
@@ -187,8 +190,8 @@ namespace Apache.Arrow.Flight.Tests
var getStream = _flightClient.GetStream(endpoint.Ticket);
- List<ByteString> actualMetadata = new List<ByteString>();
- while(await getStream.ResponseStream.MoveNext(default))
+ List<ByteString> actualMetadata = new List<ByteString>();
+ while (await getStream.ResponseStream.MoveNext(default))
{
actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
}
@@ -277,7 +280,7 @@ namespace Apache.Arrow.Flight.Tests
var actualFlights = await
listFlightStream.ResponseStream.ToListAsync();
- for(int i = 0; i < expectedFlightInfo.Count; i++)
+ for (int i = 0; i < expectedFlightInfo.Count; i++)
{
FlightInfoComparer.Compare(expectedFlightInfo[i],
actualFlights[i]);
}
@@ -386,7 +389,7 @@ namespace Apache.Arrow.Flight.Tests
List<RecordBatch> resultList = new List<RecordBatch>();
- await foreach(var recordBatch in getStream.ResponseStream)
+ await foreach (var recordBatch in getStream.ResponseStream)
{
resultList.Add(recordBatch);
}
@@ -415,5 +418,89 @@ namespace Apache.Arrow.Flight.Tests
Assert.Equal(expectedBatch.Length, result.TotalRecords);
Assert.Equal(expectedTotalBytes, result.TotalBytes);
}
+
+        /// <summary>
+        /// Verifies that each FlightClient call issued with an already-expired
+        /// deadline surfaces <see cref="StatusCode.DeadlineExceeded"/>, either via
+        /// the call status (streaming reads) or as a thrown <see cref="RpcException"/>.
+        /// </summary>
+        [Fact]
+        public async Task EnsureCallRaisesDeadlineExceeded()
+        {
+            var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_deadline");
+            // Captured before any call is made, so it is already in the past when each call starts.
+            var deadline = DateTime.UtcNow;
+            var batch = CreateTestBatch(0, 100);
+
+            RpcException exception = null;
+
+            var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline);
+            Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);
+
+            var asyncServerStreamingCallActions = _flightClient.ListActions(null, deadline);
+            // Check the ListActions call itself; re-checking the ListFlights call left it unverified.
+            Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallActions.GetStatus().StatusCode);
+
+            GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, deadline));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            // Fetch the ticket without a deadline so only the GetStream call below is expected to fail.
+            var flightInfo = await _flightClient.GetInfo(flightDescriptor);
+            var endpoint = flightInfo.Endpoints.First();
+            var getStream = _flightClient.GetStream(endpoint.Ticket, null, deadline);
+            Assert.Equal(StatusCode.DeadlineExceeded, getStream.GetStatus().StatusCode);
+
+            var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, deadline);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+
+            var handshakeStreamingCall = _flightClient.Handshake(null, deadline);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
+            Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
+        }
+
+        /// <summary>
+        /// Verifies that each FlightClient call issued with an already-cancelled
+        /// token surfaces <see cref="StatusCode.Cancelled"/>, either via the call
+        /// status (streaming reads) or as a thrown <see cref="RpcException"/>.
+        /// </summary>
+        [Fact]
+        public async Task EnsureCallRaisesRequestCancelled()
+        {
+            // Cancel up front rather than CancelAfter(1) followed by Task.Delay(5):
+            // that pairing was timing-dependent, whereas an already-cancelled token is deterministic.
+            using var cts = new CancellationTokenSource();
+            cts.Cancel();
+
+            var batch = CreateTestBatch(0, 100);
+            var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled");
+            RpcException exception = null;
+
+            var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token);
+            Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);
+
+            var asyncServerStreamingCallActions = _flightClient.ListActions(null, null, cts.Token);
+            // Check the ListActions call itself; re-checking the ListFlights call left it unverified.
+            Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallActions.GetStatus().StatusCode);
+
+            GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, null, cts.Token));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            // Fetch the ticket without the cancelled token so only the GetStream call below is expected to fail.
+            var flightInfo = await _flightClient.GetInfo(flightDescriptor);
+            var endpoint = flightInfo.Endpoints.First();
+            var getStream = _flightClient.GetStream(endpoint.Ticket, null, null, cts.Token);
+            Assert.Equal(StatusCode.Cancelled, getStream.GetStatus().StatusCode);
+
+            var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+
+            var handshakeStreamingCall = _flightClient.Handshake(null, null, cts.Token);
+            exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
+            Assert.Equal(StatusCode.Cancelled, exception.StatusCode);
+        }
}
}