CurtHagenlocher commented on code in PR #46316: URL: https://github.com/apache/arrow/pull/46316#discussion_r2275253284
########## csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs: ########## @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Flight.Middleware.Extensions; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Microsoft.Extensions.Logging; +namespace Apache.Arrow.Flight.Middleware; + +public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public readonly ConcurrentDictionary<string, Cookie> Cookies = new(StringComparer.OrdinalIgnoreCase); + private readonly ILoggerFactory _loggerFactory; + + public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; Review Comment: Wouldn't it make more sense to create and hold on to a single `ILogger` instance instead of creating a new one for every call? ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
 See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ Review Comment: It would be good to consistently use file-scoped `namespace` declarations (e.g. `namespace Apache.Arrow.Flight.Middleware.Interceptors;`) instead of block-scoped namespaces (`namespace ... { ... }`) throughout this project. ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ + public sealed class ClientInterceptorAdapter : Interceptor + { + private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories; + + public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories) + { + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall<TResponse>( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, context.Host, callOptions); + + var call = continuation(request, 
 newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception == null && task.Result != null) + { + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + return task.Result; Review Comment: If `task.Exception` can be non-null, shouldn't this throw the exception instead of returning null? ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ + public sealed class ClientInterceptorAdapter : Interceptor + { + private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories; + + public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories) + { + _factories = factories?.ToList() ??
throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall<TResponse>( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, context.Host, callOptions); + + var call = continuation(request, newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception == null && task.Result != null) + { + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + return task.Result; + }); + + var wrappedResponseStream = new MiddlewareResponseStream<TResponse>( + call.ResponseStream, + call, + middlewares); + + return new AsyncServerStreamingCall<TResponse>( + wrappedResponseStream, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + Review Comment: nit: extra blank line ########## 
csharp/src/Apache.Arrow.Flight/Middleware/MetadataAdapter.cs: ########## @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Middleware; + +public class MetadataAdapter : ICallHeaders Review Comment: This class doesn't appear to be used in this PR. Is it possible it's been superseded by `CallHeaders`? ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ + public sealed class ClientInterceptorAdapter : Interceptor + { + private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories; + + public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories) + { + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall<TResponse>( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + 
AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, context.Host, callOptions); + + var call = continuation(request, newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception == null && task.Result != null) + { + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + return task.Result; + }); + + var wrappedResponseStream = new MiddlewareResponseStream<TResponse>( + call.ResponseStream, + call, + middlewares); + + return new AsyncServerStreamingCall<TResponse>( + wrappedResponseStream, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + + private CallOptions InterceptCall<TRequest, TResponse>( + ClientInterceptorContext<TRequest, TResponse> context, + out List<IFlightClientMiddleware> middlewareList) + where TRequest : class + where TResponse : class + { + var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + + var headers = context.Options.Headers ?? 
new Metadata(); + middlewareList = new List<IFlightClientMiddleware>(); + + var callHeaders = new CallHeaders(headers); + + foreach (var factory in _factories) + { + var middleware = factory.OnCallStarted(callInfo); + middleware?.OnBeforeSendingHeaders(callHeaders); + middlewareList.Add(middleware); + } + + return context.Options.WithHeaders(headers); + } + + private async Task<TResponse> HandleResponse<TResponse>( + Task<TResponse> responseTask, + Task<Metadata> headersTask, + Func<Status> getStatus, + Func<Metadata> getTrailers, + Action dispose, + List<IFlightClientMiddleware> middlewares) + { + try + { + var headers = await headersTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + var response = await responseTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnCallCompleted(getStatus(), getTrailers()); + } + + return response; + } + catch + { + foreach (var m in middlewares) Review Comment: If the very last call to `OnCallCompleted` throws an exception, we're going to cycle through the list a second time calling all the middleware again. Consider instead catching errors inside of the enumeration and then only bubbling the exception out at the very end -- perhaps wrapping in an `AggregateException` if there is more than one. ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ + public sealed class ClientInterceptorAdapter : Interceptor + { + private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories; + + public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories) + { + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall<TResponse>( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + 
AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, context.Host, callOptions); + + var call = continuation(request, newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception == null && task.Result != null) + { + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + return task.Result; + }); + + var wrappedResponseStream = new MiddlewareResponseStream<TResponse>( + call.ResponseStream, + call, + middlewares); + + return new AsyncServerStreamingCall<TResponse>( + wrappedResponseStream, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + + private CallOptions InterceptCall<TRequest, TResponse>( + ClientInterceptorContext<TRequest, TResponse> context, + out List<IFlightClientMiddleware> middlewareList) + where TRequest : class + where TResponse : class + { + var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + + var headers = context.Options.Headers ?? 
new Metadata(); + middlewareList = new List<IFlightClientMiddleware>(); + + var callHeaders = new CallHeaders(headers); + + foreach (var factory in _factories) + { + var middleware = factory.OnCallStarted(callInfo); + middleware?.OnBeforeSendingHeaders(callHeaders); + middlewareList.Add(middleware); + } + + return context.Options.WithHeaders(headers); + } + + private async Task<TResponse> HandleResponse<TResponse>( + Task<TResponse> responseTask, + Task<Metadata> headersTask, + Func<Status> getStatus, + Func<Metadata> getTrailers, + Action dispose, + List<IFlightClientMiddleware> middlewares) + { + try + { + var headers = await headersTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + var response = await responseTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnCallCompleted(getStatus(), getTrailers()); Review Comment: Can the `getStatus()` and `getTrailers()` be lifted out of the loop instead of being called repeatedly? If necessary, this could be e.g. `var status = middlewares.Count > 0 ? getStatus() : default` to avoid the calls when no middleware is present. ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ + public sealed class ClientInterceptorAdapter : Interceptor + { + private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories; + + public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories) + { + _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall<TResponse>( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, context.Host, callOptions); + + var call = continuation(request, 
newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception == null && task.Result != null) + { + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); Review Comment: Can the `new CallHeaders(headers)` be lifted out of the loop to save allocations? ########## csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs: ########## @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Flight.Middleware.Extensions; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Microsoft.Extensions.Logging; +namespace Apache.Arrow.Flight.Middleware; Review Comment: nit: insert a blank line before the namespace declaration ########## csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs: ########## @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Grpc.Core.Interceptors; + +namespace Apache.Arrow.Flight.Middleware.Interceptors +{ + public sealed class ClientInterceptorAdapter : Interceptor + { + private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories; + + public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories) + { + _factories = factories?.ToList() ?? 
throw new ArgumentNullException(nameof(factories)); + } + + public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncUnaryCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var options = InterceptCall(context, out var middlewares); + + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, + context.Host, + options); + + var call = continuation(request, newContext); + + return new AsyncUnaryCall<TResponse>( + HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers, + call.Dispose, middlewares), + call.ResponseHeadersAsync, + call.GetStatus, + call.GetTrailers, + call.Dispose + ); + } + + public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>( + TRequest request, + ClientInterceptorContext<TRequest, TResponse> context, + AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation) + where TRequest : class + where TResponse : class + { + var callOptions = InterceptCall(context, out var middlewares); + var newContext = new ClientInterceptorContext<TRequest, TResponse>( + context.Method, context.Host, callOptions); + + var call = continuation(request, newContext); + + var responseHeadersTask = call.ResponseHeadersAsync.ContinueWith(task => + { + if (task.Exception == null && task.Result != null) + { + var headers = task.Result; + foreach (var m in middlewares) + m?.OnHeadersReceived(new CallHeaders(headers)); + } + + return task.Result; + }); + + var wrappedResponseStream = new MiddlewareResponseStream<TResponse>( + call.ResponseStream, + call, + middlewares); + + return new AsyncServerStreamingCall<TResponse>( + wrappedResponseStream, + responseHeadersTask, + call.GetStatus, + call.GetTrailers, + call.Dispose); + } + + + private CallOptions InterceptCall<TRequest, TResponse>( + 
ClientInterceptorContext<TRequest, TResponse> context, + out List<IFlightClientMiddleware> middlewareList) + where TRequest : class + where TResponse : class + { + var callInfo = new CallInfo(context.Method.FullName, context.Method.Type); + + var headers = context.Options.Headers ?? new Metadata(); + middlewareList = new List<IFlightClientMiddleware>(); + + var callHeaders = new CallHeaders(headers); + + foreach (var factory in _factories) + { + var middleware = factory.OnCallStarted(callInfo); + middleware?.OnBeforeSendingHeaders(callHeaders); + middlewareList.Add(middleware); + } + + return context.Options.WithHeaders(headers); + } + + private async Task<TResponse> HandleResponse<TResponse>( + Task<TResponse> responseTask, + Task<Metadata> headersTask, + Func<Status> getStatus, + Func<Metadata> getTrailers, + Action dispose, + List<IFlightClientMiddleware> middlewares) + { + try + { + var headers = await headersTask.ConfigureAwait(false); + foreach (var m in middlewares) + { + m?.OnHeadersReceived(new CallHeaders(headers)); Review Comment: Can the `new CallHeaders(headers)` be lifted out of the loop to save allocations? ########## csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs: ########## @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Globalization; +using System.Net; +using Apache.Arrow.Flight.Middleware.Extensions; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Microsoft.Extensions.Logging; +namespace Apache.Arrow.Flight.Middleware; + +public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory +{ + public readonly ConcurrentDictionary<string, Cookie> Cookies = new(StringComparer.OrdinalIgnoreCase); + private readonly ILoggerFactory _loggerFactory; + + public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory) + { + _loggerFactory = loggerFactory; + } + + public IFlightClientMiddleware OnCallStarted(CallInfo callInfo) + { + var logger = _loggerFactory.CreateLogger<ClientCookieMiddleware>(); + return new ClientCookieMiddleware(this, logger); + } + + internal void UpdateCookies(IEnumerable<string> newCookieHeaderValues) + { + var logger = _loggerFactory.CreateLogger<ClientCookieMiddleware>(); + foreach (var headerValue in newCookieHeaderValues) + { + try + { + foreach (var parsedCookie in headerValue.ParseHeader()) + { + var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture); + if (parsedCookie.IsExpired(headerValue)) + { + Cookies.TryRemove(nameLc, out _); + } + else + { + Cookies[nameLc] = parsedCookie; + } + } + } + catch (FormatException ex) + { + Review Comment: nit: remove blank lines here and at line 69 ########## csharp/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs: ########## @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using Apache.Arrow.Flight.Middleware.Interfaces; +using Grpc.Core; +using Microsoft.Extensions.Logging; + +namespace Apache.Arrow.Flight.Middleware; + +public class ClientCookieMiddleware : IFlightClientMiddleware +{ + private readonly ClientCookieMiddlewareFactory _factory; + private readonly ILogger<ClientCookieMiddleware> _logger; + private const string SET_COOKIE_HEADER = "Set-Cookie"; + private const string COOKIE_HEADER = "Cookie"; + + public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory, + ILogger<ClientCookieMiddleware> logger) + { + _factory = factory; + _logger = logger; + } + + public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders) + { + if (_factory.Cookies.IsEmpty) + return; + var cookieValue = GetValidCookiesAsString(); + if (!string.IsNullOrEmpty(cookieValue)) + { + outgoingHeaders.Insert(COOKIE_HEADER, cookieValue); + } + _logger.LogInformation("Sending Headers: " + string.Join(", ", outgoingHeaders)); Review Comment: What does this log? It doesn't look like any of the types implementing `ICallHeaders` have overridden `ToString`. And given that the cookies might include credential information, it's probably not a good idea to log all their values. Same comment applies to line 53. -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
