HackPoint commented on code in PR #46316:
URL: https://github.com/apache/arrow/pull/46316#discussion_r2282038757


##########
csharp/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs:
##########
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+using Grpc.Core.Interceptors;
+
+namespace Apache.Arrow.Flight.Middleware.Interceptors
+{
+    public sealed class ClientInterceptorAdapter : Interceptor
+    {
+        private readonly IReadOnlyList<IFlightClientMiddlewareFactory> 
_factories;
+
+        public 
ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories)
+        {
+            _factories = factories?.ToList() ?? throw new 
ArgumentNullException(nameof(factories));
+        }
+
+        public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, 
TResponse>(
+            TRequest request,
+            ClientInterceptorContext<TRequest, TResponse> context,
+            AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
+            where TRequest : class
+            where TResponse : class
+        {
+            var options = InterceptCall(context, out var middlewares);
+
+            var newContext = new ClientInterceptorContext<TRequest, TResponse>(
+                context.Method,
+                context.Host,
+                options);
+
+            var call = continuation(request, newContext);
+
+            return new AsyncUnaryCall<TResponse>(
+                HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, 
call.GetStatus, call.GetTrailers,
+                    call.Dispose, middlewares),
+                call.ResponseHeadersAsync,
+                call.GetStatus,
+                call.GetTrailers,
+                call.Dispose
+            );
+        }
+
+        public override AsyncServerStreamingCall<TResponse> 
AsyncServerStreamingCall<TRequest, TResponse>(
+            TRequest request,
+            ClientInterceptorContext<TRequest, TResponse> context,
+            AsyncServerStreamingCallContinuation<TRequest, TResponse> 
continuation)
+            where TRequest : class
+            where TResponse : class
+        {
+            var callOptions = InterceptCall(context, out var middlewares);
+            var newContext = new ClientInterceptorContext<TRequest, TResponse>(
+                context.Method, context.Host, callOptions);
+
+            var call = continuation(request, newContext);
+
+            var responseHeadersTask = 
call.ResponseHeadersAsync.ContinueWith(task =>
+            {
+                if (task.Exception == null && task.Result != null)
+                {
+                    var headers = task.Result;
+                    foreach (var m in middlewares)
+                        m?.OnHeadersReceived(new CallHeaders(headers));
+                }
+
+                return task.Result;
+            });
+
+            var wrappedResponseStream = new 
MiddlewareResponseStream<TResponse>(
+                call.ResponseStream,
+                call,
+                middlewares);
+
+            return new AsyncServerStreamingCall<TResponse>(
+                wrappedResponseStream,
+                responseHeadersTask,
+                call.GetStatus,
+                call.GetTrailers,
+                call.Dispose);
+        }
+
+
+        private CallOptions InterceptCall<TRequest, TResponse>(
+            ClientInterceptorContext<TRequest, TResponse> context,
+            out List<IFlightClientMiddleware> middlewareList)
+            where TRequest : class
+            where TResponse : class
+        {
+            var callInfo = new CallInfo(context.Method.FullName, 
context.Method.Type);
+
+            var headers = context.Options.Headers ?? new Metadata();
+            middlewareList = new List<IFlightClientMiddleware>();
+
+            var callHeaders = new CallHeaders(headers);
+
+            foreach (var factory in _factories)
+            {
+                var middleware = factory.OnCallStarted(callInfo);
+                middleware?.OnBeforeSendingHeaders(callHeaders);
+                middlewareList.Add(middleware);
+            }
+
+            return context.Options.WithHeaders(headers);
+        }
+
+        private async Task<TResponse> HandleResponse<TResponse>(
+            Task<TResponse> responseTask,
+            Task<Metadata> headersTask,
+            Func<Status> getStatus,
+            Func<Metadata> getTrailers,
+            Action dispose,
+            List<IFlightClientMiddleware> middlewares)
+        {
+            try
+            {
+                var headers = await headersTask.ConfigureAwait(false);
+                foreach (var m in middlewares)
+                {
+                    m?.OnHeadersReceived(new CallHeaders(headers));
+                }
+
+                var response = await responseTask.ConfigureAwait(false);
+                foreach (var m in middlewares)
+                {
+                    m?.OnCallCompleted(getStatus(), getTrailers());

Review Comment:
   Good catch:
   getStatus() / getTrailers() are pure (no side effects) and don’t change 
between middleware calls, so capture them once per completion path and reuse 
them. Also skip all the work if there are no (non-null) middlewares.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to