This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-dotnet.git
The following commit(s) were added to refs/heads/main by this push:
new 20cb82b feat: Implement Arrow Flight Middleware support (#139)
20cb82b is described below
commit 20cb82b7f4045583df77f320129ca1af28c740db
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri Nov 21 07:16:04 2025 -0800
feat: Implement Arrow Flight Middleware support (#139)
### Enhancement: Apache Arrow Flight Middleware in C#
Co-authored-by: @HackPoint
Original change in
[https://github.com/apache/arrow/pull/46316](https://github.com/apache/arrow/pull/46316).
#### Overview
This Pull Request enhances middleware support for Apache Arrow Flight
using C#, focusing on improved metadata header management and
propagation for better observability and extensibility. It also provides
handling for HTTP/HTTPS communication.
#### Rationale for this Change
Effective middleware is critical for managing metadata headers, ensuring
accurate request/response handling, and simplifying debugging in
distributed systems. By improving middleware capabilities, we enhance
reliability and observability, significantly benefiting developers and
operational teams managing complex Flight-based applications.
#### What's Included in this PR?
- Middleware enhancements supporting complete metadata header
propagation.
- Middleware lifecycle hooks for better request/response management.
- Comprehensive integration tests validating middleware functionality.
- Documentation updates reflecting middleware improvements.
#### Key Features
- **Complete Header Propagation:** Ensures accurate propagation of gRPC
metadata headers throughout middleware lifecycle events.
- **HTTP/HTTPS Handling:** Supports middleware integration and metadata
propagation for HTTP and HTTPS communications.
- **Middleware Lifecycle Management:** Supports reliable middleware
hooks (`OnBeforeSendingHeaders`, `OnHeadersReceived`,
`OnCallCompleted`).
- **Enhanced Testing:** Adds comprehensive integration tests to verify
correct middleware behavior.
#### Impact
- Improves middleware reliability and simplifies debugging.
- Enhances transparency in gRPC and HTTP/S communication within
Flight-based applications.
#### Are These Changes Tested?
**Testing Overview**
**Unit Tests:**
- Added tests for middleware lifecycle event execution (e.g.,
`OnBeforeSendingHeaders`, `OnHeadersReceived`, `OnCallCompleted`).
- Verified internal logic for capturing and storing gRPC metadata
headers.
**Integration Tests:**
- Tested end-to-end with a real Flight client and in-memory server
setup.
- Validated propagation of custom headers (e.g., `x-server-header`,
`Set-Cookie`) between client and server.
**End-to-End Tests:**
- Simulated real-world Flight requests to ensure headers are processed
consistently across middleware layers.
- Confirmed correct invocation order and middleware behavior under
different server responses.
**Example Test Cases:**
- Verify that `OnHeadersReceived` correctly captures server-sent
headers.
- Ensure custom client middleware modifies request headers as expected.
- Validate that `OnCallCompleted` is triggered on both success and error
cases.
#### Checklist
- [x] Implementation completed
- [x] Tests added and passing
Closes GitHub Issue: #138
---
src/Apache.Arrow.Flight/Middleware/CallHeaders.cs | 80 +++++++++
src/Apache.Arrow.Flight/Middleware/CallInfo.cs | 35 ++++
.../Middleware/ClientCookieMiddleware.cs | 75 +++++++++
.../Middleware/ClientCookieMiddlewareFactory.cs | 67 ++++++++
.../Middleware/Extensions/CookieExtensions.cs | 104 ++++++++++++
.../Interceptors/ClientInterceptorAdapter.cs | 186 +++++++++++++++++++++
.../Interceptors/MiddlewareResponseStream.cs | 69 ++++++++
.../Middleware/Interfaces/ICallHeaders.cs | 34 ++++
.../Interfaces/IFlightClientMiddleware.cs | 25 +++
.../Interfaces/IFlightClientMiddlewareFactory.cs | 21 +++
.../MiddlewareTests/CallHeadersTests.cs | 135 +++++++++++++++
.../MiddlewareTests/ClientCookieMiddlewareTests.cs | 137 +++++++++++++++
.../ClientInterceptorAdapterTests.cs | 99 +++++++++++
.../MiddlewareTests/CookieExtensionsTests.cs | 127 ++++++++++++++
.../MiddlewareTests/Stubs/CapturingMiddleware.cs | 58 +++++++
.../Stubs/CapturingMiddlewareFactory.cs | 28 ++++
.../Stubs/ClientCookieMiddlewareMock.cs | 64 +++++++
.../MiddlewareTests/Stubs/InMemoryCallHeaders.cs | 90 ++++++++++
.../MiddlewareTests/Stubs/InMemoryFlightStore.cs | 49 ++++++
19 files changed, 1483 insertions(+)
diff --git a/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs
b/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs
new file mode 100644
index 0000000..538c549
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/CallHeaders.cs
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Middleware;
+
+/// <summary>
+/// <see cref="ICallHeaders"/> implementation backed by a gRPC <see cref="Metadata"/> collection.
+/// All key lookups are case-insensitive (ordinal). Mutations are applied directly to the
+/// wrapped <see cref="Metadata"/> instance.
+/// </summary>
+public class CallHeaders : ICallHeaders, IEnumerable<KeyValuePair<string, string>>
+{
+    private readonly Metadata _metadata;
+
+    /// <summary>Wraps <paramref name="metadata"/>; no copy is made.</summary>
+    /// <param name="metadata">The metadata collection to read from and write to.</param>
+    public CallHeaders(Metadata metadata)
+    {
+        _metadata = metadata;
+    }
+
+    /// <summary>Appends a string-valued entry; existing entries with the same key are kept.</summary>
+    public void Add(string key, string value) => _metadata.Add(key, value);
+
+    /// <summary>Returns true when at least one entry has the given key.</summary>
+    public bool ContainsKey(string key) => _metadata.Any(h => KeyEquals(h.Key, key));
+
+    /// <summary>
+    /// Enumerates the string-valued entries as key/value pairs.
+    /// Binary entries are skipped because they have no string value: reading
+    /// <c>Metadata.Entry.Value</c> on a binary entry throws, which previously aborted the
+    /// whole enumeration. Use <see cref="GetBytes"/> / <see cref="GetAllBytes"/> for those.
+    /// </summary>
+    public IEnumerator<KeyValuePair<string, string>> GetEnumerator()
+    {
+        foreach (var entry in _metadata)
+        {
+            if (entry.IsBinary)
+                continue;
+
+            yield return new KeyValuePair<string, string>(entry.Key, entry.Value);
+        }
+    }
+
+    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+
+    /// <summary>
+    /// Gets the value of the first entry with the given key (null when absent), or replaces
+    /// every existing entry for the key with a single new string value.
+    /// </summary>
+    public string this[string key]
+    {
+        get
+        {
+            var entry = _metadata.FirstOrDefault(h => KeyEquals(h.Key, key));
+            return entry?.Value;
+        }
+        set
+        {
+            // "Set" means replace: drop ALL entries for the key, not just the first one,
+            // otherwise older duplicate values would survive next to the new value.
+            // Iterate backwards so RemoveAt does not shift the indexes still to be visited.
+            for (var i = _metadata.Count - 1; i >= 0; i--)
+            {
+                if (KeyEquals(_metadata[i].Key, key))
+                    _metadata.RemoveAt(i);
+            }
+
+            _metadata.Add(key, value);
+        }
+    }
+
+    /// <summary>Same as the indexer getter: first value for the key, or null.</summary>
+    public string Get(string key) => this[key];
+
+    /// <summary>Returns the raw bytes of the first entry with the given key, or null.</summary>
+    public byte[] GetBytes(string key) =>
+        _metadata.FirstOrDefault(h => KeyEquals(h.Key, key))?.ValueBytes;
+
+    /// <summary>Lazily returns the string values of every entry with the given key.</summary>
+    public IEnumerable<string> GetAll(string key) =>
+        _metadata.Where(h => KeyEquals(h.Key, key)).Select(h => h.Value);
+
+    /// <summary>Lazily returns the raw bytes of every entry with the given key.</summary>
+    public IEnumerable<byte[]> GetAllBytes(string key) =>
+        _metadata.Where(h => KeyEquals(h.Key, key)).Select(h => h.ValueBytes);
+
+    /// <summary>Appends a string-valued entry (alias of <see cref="Add"/>).</summary>
+    public void Insert(string key, string value) => Add(key, value);
+
+    /// <summary>Appends a binary-valued entry.</summary>
+    public void Insert(string key, byte[] value) => _metadata.Add(key, value);
+
+    /// <summary>
+    /// Snapshot of the distinct keys currently present. The set uses the same
+    /// case-insensitive comparison as the lookup members, so <c>Keys.Contains</c>
+    /// agrees with <see cref="ContainsKey"/>.
+    /// </summary>
+    public ISet<string> Keys =>
+        new HashSet<string>(_metadata.Select(h => h.Key), StringComparer.OrdinalIgnoreCase);
+
+    private static bool KeyEquals(string a, string b) =>
+        string.Equals(a, b, StringComparison.OrdinalIgnoreCase);
+}
diff --git a/src/Apache.Arrow.Flight/Middleware/CallInfo.cs
b/src/Apache.Arrow.Flight/Middleware/CallInfo.cs
new file mode 100644
index 0000000..e18f5eb
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/CallInfo.cs
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Middleware;
+
+/// <summary>
+/// Immutable description of a call: the method name and its gRPC <see cref="MethodType"/>.
+/// </summary>
+public readonly struct CallInfo
+{
+    /// <summary>Creates a call description from a method name and method type.</summary>
+    public CallInfo(string method, MethodType methodType)
+    {
+        MethodType = methodType;
+        Method = method;
+    }
+
+    /// <summary>The method name supplied at construction.</summary>
+    public string Method { get; }
+
+    /// <summary>The gRPC method type supplied at construction.</summary>
+    public MethodType MethodType { get; }
+
+    /// <summary>Formats the value as "<c>MethodType: Method</c>".</summary>
+    public override string ToString() => $"{MethodType}: {Method}";
+}
diff --git a/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs
b/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs
new file mode 100644
index 0000000..543b594
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddleware.cs
@@ -0,0 +1,75 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Collections.Generic;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+using Microsoft.Extensions.Logging;
+
+namespace Apache.Arrow.Flight.Middleware;
+
+/// <summary>
+/// Per-call middleware that sends the cookies held by its
+/// <see cref="ClientCookieMiddlewareFactory"/> and stores cookies returned by the server.
+/// </summary>
+public class ClientCookieMiddleware : IFlightClientMiddleware
+{
+    private const string SetCookieHeader = "Set-Cookie";
+    private const string CookieHeader = "Cookie";
+
+    private readonly ClientCookieMiddlewareFactory _factory;
+    private readonly ILogger<ClientCookieMiddleware> _logger;
+
+    /// <summary>Creates a middleware bound to the cookie store owned by <paramref name="factory"/>.</summary>
+    public ClientCookieMiddleware(ClientCookieMiddlewareFactory factory,
+        ILogger<ClientCookieMiddleware> logger)
+    {
+        _factory = factory;
+        _logger = logger;
+    }
+
+    /// <summary>
+    /// Adds a single <c>Cookie</c> header containing every non-expired cookie.
+    /// Nothing is added when the store is empty or only holds expired cookies.
+    /// </summary>
+    public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders)
+    {
+        if (!_factory.Cookies.IsEmpty)
+        {
+            var headerValue = JoinLiveCookies();
+            if (!string.IsNullOrEmpty(headerValue))
+            {
+                outgoingHeaders.Insert(CookieHeader, headerValue);
+            }
+        }
+    }
+
+    /// <summary>Passes every received <c>Set-Cookie</c> value to the factory's cookie store.</summary>
+    public void OnHeadersReceived(ICallHeaders incomingHeaders)
+    {
+        _factory.UpdateCookies(incomingHeaders.GetAll(SetCookieHeader));
+    }
+
+    /// <summary>Intentionally does nothing with the final status or trailers.</summary>
+    public void OnCallCompleted(Status status, Metadata trailers)
+    {
+        // ingest: status and/or metadata trailers
+    }
+
+    // Builds the "a=1; b=2" header value and evicts expired cookies from the store on the way.
+    private string JoinLiveCookies()
+    {
+        var live = new List<string>();
+        foreach (var pair in _factory.Cookies)
+        {
+            var cookie = pair.Value;
+            if (!cookie.Expired)
+            {
+                live.Add(cookie.ToString());
+                continue;
+            }
+
+            _factory.Cookies.TryRemove(pair.Key, out _);
+        }
+
+        return string.Join("; ", live);
+    }
+}
diff --git
a/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs
b/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs
new file mode 100644
index 0000000..8dd8114
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/ClientCookieMiddlewareFactory.cs
@@ -0,0 +1,67 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Net;
+using Apache.Arrow.Flight.Middleware.Extensions;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Microsoft.Extensions.Logging;
+
+namespace Apache.Arrow.Flight.Middleware;
+
+/// <summary>
+/// Factory that owns the cookie store shared by every <see cref="ClientCookieMiddleware"/>
+/// it creates. Cookies are keyed by lower-cased name in a case-insensitive dictionary.
+/// </summary>
+public class ClientCookieMiddlewareFactory : IFlightClientMiddlewareFactory
+{
+    /// <summary>Cookies currently held, keyed by cookie name.</summary>
+    public readonly ConcurrentDictionary<string, Cookie> Cookies = new(StringComparer.OrdinalIgnoreCase);
+    private readonly ILogger<ClientCookieMiddleware> _logger;
+
+    /// <summary>Creates the factory and the logger handed to each middleware instance.</summary>
+    /// <exception cref="ArgumentNullException"><paramref name="loggerFactory"/> is null.</exception>
+    public ClientCookieMiddlewareFactory(ILoggerFactory loggerFactory)
+    {
+        // Fail with a named argument instead of a bare NullReferenceException.
+        if (loggerFactory == null)
+            throw new ArgumentNullException(nameof(loggerFactory));
+
+        _logger = loggerFactory.CreateLogger<ClientCookieMiddleware>();
+    }
+
+    /// <summary>Returns a new cookie middleware bound to this factory's store.</summary>
+    public IFlightClientMiddleware OnCallStarted(CallInfo callInfo)
+    {
+        return new ClientCookieMiddleware(this, _logger);
+    }
+
+    /// <summary>
+    /// Applies a batch of <c>Set-Cookie</c> header values: expired cookies are removed from
+    /// the store, all others are inserted or overwritten. Malformed values are logged and
+    /// skipped so that one bad header cannot fail the call.
+    /// </summary>
+    internal void UpdateCookies(IEnumerable<string> newCookieHeaderValues)
+    {
+        if (newCookieHeaderValues == null)
+            return;
+
+        foreach (var headerValue in newCookieHeaderValues)
+        {
+            try
+            {
+                foreach (var parsedCookie in headerValue.ParseHeader())
+                {
+                    var nameLc = parsedCookie.Name.ToLower(CultureInfo.InvariantCulture);
+                    if (parsedCookie.IsExpired(headerValue))
+                    {
+                        Cookies.TryRemove(nameLc, out _);
+                    }
+                    else
+                    {
+                        Cookies[nameLc] = parsedCookie;
+                    }
+                }
+            }
+            // IndexOutOfRangeException is included because attribute parsing indexes into a
+            // split result; a malformed attribute must be skipped like any other bad header.
+            catch (Exception ex) when (ex is FormatException || ex is IndexOutOfRangeException)
+            {
+                _logger.LogWarning(ex, "Skipping malformed Set-Cookie header: '{HeaderValue}'", headerValue);
+            }
+        }
+    }
+}
diff --git a/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs
b/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs
new file mode 100644
index 0000000..db72e1c
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/Extensions/CookieExtensions.cs
@@ -0,0 +1,104 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Globalization;
+using System.Linq;
+using System.Net;
+
+namespace Apache.Arrow.Flight.Middleware.Extensions;
+
+/// <summary>
+/// Helpers for turning a raw <c>Set-Cookie</c> header value into a <see cref="Cookie"/>
+/// and for deciding whether that cookie should be treated as expired.
+/// </summary>
+public static class CookieExtensions
+{
+    /// <summary>
+    /// Parses one <c>Set-Cookie</c> header value.
+    /// </summary>
+    /// <param name="setCookieHeader">The raw header value, e.g. <c>"sid=abc; Path=/; Secure"</c>.</param>
+    /// <returns>
+    /// A sequence with the single parsed cookie, or an empty sequence when the header is
+    /// null/blank or its first segment is not a <c>name=value</c> pair with a non-blank name.
+    /// </returns>
+    /// <remarks>
+    /// Recognized attributes: Expires, Max-Age, Domain, Path, Secure, HttpOnly. Unknown or
+    /// malformed attributes are ignored. Max-Age takes precedence over Expires regardless of
+    /// the order in which they appear. The <see cref="Cookie"/> constructor may still throw
+    /// for an invalid cookie name.
+    /// </remarks>
+    public static IEnumerable<Cookie> ParseHeader(this string setCookieHeader)
+    {
+        if (string.IsNullOrWhiteSpace(setCookieHeader))
+            return System.Array.Empty<Cookie>();
+
+        var cookies = new List<Cookie>();
+
+        var segments = setCookieHeader.Split([';'], StringSplitOptions.RemoveEmptyEntries);
+        if (segments.Length == 0)
+            return cookies;
+
+        var nameValue = segments[0].Split(['='], 2);
+        if (nameValue.Length != 2 || string.IsNullOrWhiteSpace(nameValue[0]))
+            return cookies;
+
+        var cookie = new Cookie(nameValue[0].Trim(), nameValue[1].Trim());
+
+        // Once Max-Age has set the expiry, a later Expires attribute must not override it.
+        var maxAgeApplied = false;
+
+        foreach (var segment in segments.Skip(1))
+        {
+            // No RemoveEmptyEntries here: Split then always yields at least one element, so
+            // a degenerate segment such as "=" can no longer make kv[0] throw.
+            var kv = segment.Split(['='], 2);
+            var key = kv[0].Trim().ToLowerInvariant();
+            var val = kv.Length > 1 ? kv[1].Trim() : null;
+
+            switch (key)
+            {
+                case "expires":
+                    if (maxAgeApplied || string.IsNullOrWhiteSpace(val))
+                        break;
+
+                    // Header data is machine-readable: parse with the invariant culture.
+                    // The fallback treats a date without an offset as UTC.
+                    if (DateTimeOffset.TryParseExact(val, "R", CultureInfo.InvariantCulture, DateTimeStyles.None, out var expiresRfc))
+                        cookie.Expires = expiresRfc.UtcDateTime;
+                    else if (DateTimeOffset.TryParse(val, CultureInfo.InvariantCulture, DateTimeStyles.AssumeUniversal, out var expiresFallback))
+                        cookie.Expires = expiresFallback.UtcDateTime;
+                    break;
+
+                case "max-age":
+                    if (int.TryParse(val, NumberStyles.Integer, CultureInfo.InvariantCulture, out var seconds))
+                    {
+                        cookie.Expires = DateTime.UtcNow.AddSeconds(seconds);
+                        maxAgeApplied = true;
+                    }
+                    break;
+
+                case "domain":
+                    cookie.Domain = val;
+                    break;
+
+                case "path":
+                    cookie.Path = val;
+                    break;
+
+                case "secure":
+                    cookie.Secure = true;
+                    break;
+
+                case "httponly":
+                    cookie.HttpOnly = true;
+                    break;
+            }
+        }
+
+        cookies.Add(cookie);
+        return cookies;
+    }
+
+    /// <summary>
+    /// Decides whether a cookie should be treated as expired (i.e. removed from the store).
+    /// </summary>
+    /// <param name="cookie">The parsed cookie; a null cookie or blank value counts as expired.</param>
+    /// <param name="rawHeader">The header the cookie was parsed from; may be null.</param>
+    /// <returns>
+    /// True when the value is blank, when the raw header contains <c>Max-Age=0</c>
+    /// (case-insensitive), or when an explicit expiry is not later than the current UTC time.
+    /// </returns>
+    public static bool IsExpired(this Cookie cookie, string rawHeader)
+    {
+        if (string.IsNullOrWhiteSpace(cookie?.Value))
+            return true;
+
+        // If raw header has Max-Age=0, consider it deleted
+        if (rawHeader?.IndexOf("Max-Age=0", StringComparison.OrdinalIgnoreCase) >= 0)
+            return true;
+
+        // DateTime.MinValue is the Cookie default and means "no expiry set".
+        if (cookie.Expires != DateTime.MinValue && cookie.Expires <= DateTime.UtcNow)
+            return true;
+
+        return false;
+    }
+}
diff --git
a/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs
b/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs
new file mode 100644
index 0000000..2a5c2dd
--- /dev/null
+++
b/src/Apache.Arrow.Flight/Middleware/Interceptors/ClientInterceptorAdapter.cs
@@ -0,0 +1,186 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+using Grpc.Core.Interceptors;
+
+namespace Apache.Arrow.Flight.Middleware.Interceptors;
+
+/// <summary>
+/// gRPC client <see cref="Interceptor"/> that creates one <see cref="IFlightClientMiddleware"/>
+/// per configured factory for each call and drives its lifecycle hooks.
+/// </summary>
+/// <remarks>
+/// Only <c>AsyncUnaryCall</c> and <c>AsyncServerStreamingCall</c> are overridden here; other
+/// call shapes keep the base <see cref="Interceptor"/> behavior and do not invoke middleware.
+/// </remarks>
+public sealed class ClientInterceptorAdapter : Interceptor
+{
+    private readonly IReadOnlyList<IFlightClientMiddlewareFactory> _factories;
+
+    /// <summary>Snapshots <paramref name="factories"/>; later changes to the source are not seen.</summary>
+    /// <exception cref="ArgumentNullException"><paramref name="factories"/> is null.</exception>
+    public ClientInterceptorAdapter(IEnumerable<IFlightClientMiddlewareFactory> factories)
+    {
+        _factories = factories?.ToList() ?? throw new ArgumentNullException(nameof(factories));
+    }
+
+    /// <summary>
+    /// Intercepts a unary call: middleware sees the outgoing headers before the call starts,
+    /// the response headers before the response is returned, and the final status/trailers
+    /// once the response task completes or fails.
+    /// </summary>
+    public override AsyncUnaryCall<TResponse> AsyncUnaryCall<TRequest, TResponse>(
+        TRequest request,
+        ClientInterceptorContext<TRequest, TResponse> context,
+        AsyncUnaryCallContinuation<TRequest, TResponse> continuation)
+        where TRequest : class
+        where TResponse : class
+    {
+        var options = InterceptCall(context, out var middlewares);
+
+        var newContext = new ClientInterceptorContext<TRequest, TResponse>(
+            context.Method,
+            context.Host,
+            options);
+
+        var call = continuation(request, newContext);
+
+        return new AsyncUnaryCall<TResponse>(
+            HandleResponse(call.ResponseAsync, call.ResponseHeadersAsync, call.GetStatus, call.GetTrailers,
+                call.Dispose, middlewares),
+            call.ResponseHeadersAsync,
+            call.GetStatus,
+            call.GetTrailers,
+            call.Dispose
+        );
+    }
+
+    /// <summary>
+    /// Intercepts a server-streaming call: middleware sees the outgoing headers before the
+    /// call starts and the response headers when they arrive; completion is reported by the
+    /// wrapping <see cref="MiddlewareResponseStream{T}"/>.
+    /// </summary>
+    public override AsyncServerStreamingCall<TResponse> AsyncServerStreamingCall<TRequest, TResponse>(
+        TRequest request,
+        ClientInterceptorContext<TRequest, TResponse> context,
+        AsyncServerStreamingCallContinuation<TRequest, TResponse> continuation)
+        where TRequest : class
+        where TResponse : class
+    {
+        var callOptions = InterceptCall(context, out var middlewares);
+        var newContext = new ClientInterceptorContext<TRequest, TResponse>(
+            context.Method, context.Host, callOptions);
+
+        var call = continuation(request, newContext);
+
+        var wrappedResponseStream = new MiddlewareResponseStream<TResponse>(
+            call.ResponseStream,
+            call,
+            middlewares);
+
+        return new AsyncServerStreamingCall<TResponse>(
+            wrappedResponseStream,
+            NotifyHeadersReceivedAsync(call.ResponseHeadersAsync, middlewares),
+            call.GetStatus,
+            call.GetTrailers,
+            call.Dispose);
+    }
+
+    // Awaits the response headers, shows them to every middleware, then hands them on.
+    // Written as an async method (instead of ContinueWith + rethrow) so that a failed or
+    // cancelled headers task surfaces its ORIGINAL exception / cancellation to whoever
+    // awaits ResponseHeadersAsync, rather than a nested AggregateException.
+    private static async Task<Metadata> NotifyHeadersReceivedAsync(
+        Task<Metadata> headersTask,
+        List<IFlightClientMiddleware> middlewares)
+    {
+        var headers = await headersTask.ConfigureAwait(false);
+
+        var callHeaders = new CallHeaders(headers);
+        foreach (var m in middlewares)
+            m?.OnHeadersReceived(callHeaders);
+
+        return headers;
+    }
+
+    // Creates the per-call middleware list and lets each middleware add outgoing headers.
+    // NOTE: when the caller supplied a Metadata instance in the call options, that same
+    // instance is passed to the middleware and therefore mutated in place.
+    private CallOptions InterceptCall<TRequest, TResponse>(
+        ClientInterceptorContext<TRequest, TResponse> context,
+        out List<IFlightClientMiddleware> middlewareList)
+        where TRequest : class
+        where TResponse : class
+    {
+        var callInfo = new CallInfo(context.Method.FullName, context.Method.Type);
+
+        var headers = context.Options.Headers ?? new Metadata();
+        middlewareList = new List<IFlightClientMiddleware>();
+
+        var callHeaders = new CallHeaders(headers);
+
+        foreach (var factory in _factories)
+        {
+            // A factory may return null; the slot is kept and skipped by the later hooks.
+            var middleware = factory.OnCallStarted(callInfo);
+            middleware?.OnBeforeSendingHeaders(callHeaders);
+            middlewareList.Add(middleware);
+        }
+
+        return context.Options.WithHeaders(headers);
+    }
+
+    // Unary response pipeline: headers -> OnHeadersReceived -> response -> OnCallCompleted.
+    // OnCallCompleted is raised exactly once, on success or failure, and the underlying
+    // call is disposed afterwards in every case.
+    private async Task<TResponse> HandleResponse<TResponse>(
+        Task<TResponse> responseTask,
+        Task<Metadata> headersTask,
+        Func<Status> getStatus,
+        Func<Metadata> getTrailers,
+        Action dispose,
+        List<IFlightClientMiddleware> middlewares)
+    {
+        var nonNullMiddlewares = (middlewares ?? new List<IFlightClientMiddleware>())
+            .Where(m => m != null)
+            .ToList();
+
+        var hasMiddlewares = nonNullMiddlewares.Count > 0;
+        var completionNotified = false;
+
+        try
+        {
+            // Always await headers to surface faults; only materialize CallHeaders if needed.
+            var headers = await headersTask.ConfigureAwait(false);
+            if (hasMiddlewares)
+            {
+                var ch = new CallHeaders(headers);
+                foreach (var m in nonNullMiddlewares)
+                    m.OnHeadersReceived(ch);
+            }
+
+            var response = await responseTask.ConfigureAwait(false);
+
+            // Single completion notification
+            NotifyCompletionOnce();
+            return response;
+        }
+        catch
+        {
+            // Completion on failure (only once)
+            NotifyCompletionOnce();
+            throw;
+        }
+        finally
+        {
+            dispose?.Invoke();
+        }
+
+        void NotifyCompletionOnce()
+        {
+            if (completionNotified || !hasMiddlewares) return;
+            completionNotified = true;
+
+            var status = getStatus();
+            var trailers = getTrailers();
+
+            foreach (var m in nonNullMiddlewares)
+                m.OnCallCompleted(status, trailers);
+        }
+    }
+}
diff --git
a/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs
b/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs
new file mode 100644
index 0000000..273e951
--- /dev/null
+++
b/src/Apache.Arrow.Flight/Middleware/Interceptors/MiddlewareResponseStream.cs
@@ -0,0 +1,69 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Collections.Generic;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Middleware.Interceptors;
+
+/// <summary>
+/// Response-stream wrapper that forwards to the inner reader and notifies middleware when
+/// the stream finishes, either by reaching its end or by throwing.
+/// </summary>
+/// <typeparam name="T">The streamed response message type.</typeparam>
+public class MiddlewareResponseStream<T> : IAsyncStreamReader<T> where T : class
+{
+    private readonly IAsyncStreamReader<T> _inner;
+    private readonly AsyncServerStreamingCall<T> _call;
+    private readonly List<IFlightClientMiddleware> _middlewareList;
+
+    // 0 until OnCallCompleted has been raised, 1 afterwards.
+    private int _completionNotified;
+
+    /// <summary>Wraps <paramref name="inner"/> for the given call and middleware list.</summary>
+    /// <param name="inner">The reader that actually produces messages.</param>
+    /// <param name="call">The call whose status and trailers are reported on completion.</param>
+    /// <param name="middlewareList">Middleware to notify; null elements are skipped.</param>
+    public MiddlewareResponseStream(
+        IAsyncStreamReader<T> inner,
+        AsyncServerStreamingCall<T> call,
+        List<IFlightClientMiddleware> middlewareList)
+    {
+        _inner = inner;
+        _call = call;
+        _middlewareList = middlewareList;
+    }
+
+    /// <summary>The current message of the inner reader.</summary>
+    public T Current => _inner.Current;
+
+    /// <summary>
+    /// Advances the inner reader. When it reports the end of the stream or throws,
+    /// middleware receives <c>OnCallCompleted</c>; exceptions are rethrown unchanged.
+    /// </summary>
+    public async Task<bool> MoveNext(CancellationToken cancellationToken)
+    {
+        try
+        {
+            bool hasNext = await _inner.MoveNext(cancellationToken).ConfigureAwait(false);
+            if (!hasNext)
+            {
+                TriggerOnCallCompleted();
+            }
+
+            return hasNext;
+        }
+        catch
+        {
+            TriggerOnCallCompleted();
+            throw;
+        }
+    }
+
+    // Raises OnCallCompleted at most once per stream. Without the guard, every extra
+    // MoveNext after the end (or after a failure) would notify the middleware again,
+    // unlike the unary path, which reports completion a single time.
+    private void TriggerOnCallCompleted()
+    {
+        if (Interlocked.Exchange(ref _completionNotified, 1) != 0)
+            return;
+
+        var status = _call.GetStatus();
+        var trailers = _call.GetTrailers();
+
+        foreach (var m in _middlewareList)
+            m?.OnCallCompleted(status, trailers);
+    }
+}
diff --git a/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs
b/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs
new file mode 100644
index 0000000..aba5ba3
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/Interfaces/ICallHeaders.cs
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Collections.Generic;
+
+namespace Apache.Arrow.Flight.Middleware.Interfaces;
+
+/// <summary>
+/// Read/append view of the headers of a call, handed to client middleware.
+/// The CallHeaders implementation added in this change matches keys case-insensitively.
+/// </summary>
+public interface ICallHeaders
+{
+ /// <summary>Gets a single string value for <paramref name="key"/> (CallHeaders returns the first match, or null when absent).</summary>
+ string this[string key] { get; }
+
+ /// <summary>Gets a single string value for <paramref name="key"/>; CallHeaders forwards this to the indexer.</summary>
+ string Get(string key);
+ /// <summary>Gets a single value for <paramref name="key"/> as raw bytes.</summary>
+ byte[] GetBytes(string key);
+ /// <summary>Gets every string value stored under <paramref name="key"/>.</summary>
+ IEnumerable<string> GetAll(string key);
+ /// <summary>Gets every value stored under <paramref name="key"/> as raw bytes.</summary>
+ IEnumerable<byte[]> GetAllBytes(string key);
+
+ /// <summary>Adds a string value under <paramref name="key"/> (CallHeaders appends; existing values are kept).</summary>
+ void Insert(string key, string value);
+ /// <summary>Adds a binary value under <paramref name="key"/>.</summary>
+ void Insert(string key, byte[] value);
+
+ /// <summary>The set of header keys currently present.</summary>
+ ISet<string> Keys { get; }
+ /// <summary>Returns whether at least one value is stored under <paramref name="key"/>.</summary>
+ bool ContainsKey(string key);
+}
diff --git
a/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs
b/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs
new file mode 100644
index 0000000..d34686d
--- /dev/null
+++ b/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddleware.cs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Middleware.Interfaces;
+
+/// <summary>
+/// Per-call client middleware. ClientInterceptorAdapter creates one instance per call via
+/// a factory and invokes the hooks below in the order they are declared.
+/// </summary>
+public interface IFlightClientMiddleware
+{
+ /// <summary>Called before the call is started; the middleware may add outgoing headers.</summary>
+ void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders);
+ /// <summary>Called with the response headers once they have been received.</summary>
+ void OnHeadersReceived(ICallHeaders incomingHeaders);
+ /// <summary>Called with the call's status and trailers when the call completes or fails.</summary>
+ void OnCallCompleted(Status status, Metadata trailers);
+}
diff --git
a/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs
b/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs
new file mode 100644
index 0000000..6ae7456
--- /dev/null
+++
b/src/Apache.Arrow.Flight/Middleware/Interfaces/IFlightClientMiddlewareFactory.cs
@@ -0,0 +1,21 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace Apache.Arrow.Flight.Middleware.Interfaces;
+
+/// <summary>
+/// Creates the middleware instance used for a single call.
+/// </summary>
+public interface IFlightClientMiddlewareFactory
+{
+ /// <summary>
+ /// Called by ClientInterceptorAdapter at the start of each intercepted call.
+ /// A null return is tolerated by the adapter, which skips the hooks for that slot.
+ /// </summary>
+ /// <param name="callInfo">The method name and gRPC method type of the call.</param>
+ IFlightClientMiddleware OnCallStarted(CallInfo callInfo);
+}
diff --git a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs
new file mode 100644
index 0000000..2f69a75
--- /dev/null
+++ b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CallHeadersTests.cs
@@ -0,0 +1,135 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Linq;
+using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+using Xunit;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests;
+
+/// <summary>
+/// Unit tests for the <see cref="InMemoryCallHeaders"/> stub: string and binary inserts,
+/// single- and multi-value lookups, key enumeration, and lookups of missing keys.
+/// </summary>
+public class CallHeadersTests
+{
+    // xUnit constructs a new instance of the test class for every test, so each test
+    // starts from an empty header collection.
+    private readonly InMemoryCallHeaders _headers = new();
+
+    [Fact]
+    public void InsertAndGetStringValue()
+    {
+        _headers.Insert("Auth", "Bearer 123");
+
+        Assert.Equal("Bearer 123", _headers.Get("Auth"));
+        Assert.Equal("Bearer 123", _headers["Auth"]);
+    }
+
+    [Fact]
+    public void InsertAndGetByteArrayValue()
+    {
+        var bytes = new byte[] { 1, 2, 3, 4, 5 };
+
+        _headers.Insert("Data", bytes);
+
+        Assert.Equal(bytes, _headers.GetBytes("Data"));
+    }
+
+    // Previously named InsertMultipleValuesAndGetLast, which contradicted the assertion:
+    // for a repeated string key the test expects Get to yield the FIRST inserted value.
+    [Fact]
+    public void InsertMultipleValuesAndGetFirst()
+    {
+        _headers.Insert("User", "Alice");
+        _headers.Insert("User", "Bob");
+
+        Assert.Equal("Alice", _headers.Get("User"));
+    }
+
+    [Fact]
+    public void GetAllShouldReturnAllStringValues()
+    {
+        _headers.Insert("Header", "v1");
+        _headers.Insert("Header", "v2");
+
+        var all = _headers.GetAll("Header").ToList();
+
+        Assert.Contains("v1", all);
+        Assert.Contains("v2", all);
+        Assert.Equal(2, all.Count);
+    }
+
+    [Fact]
+    public void GetAllBytesShouldReturnAllByteArrayValues()
+    {
+        var a = new byte[] { 1 };
+        var b = new byte[] { 2 };
+        _headers.Insert("Binary", a);
+        _headers.Insert("Binary", b);
+
+        var all = _headers.GetAllBytes("Binary").ToList();
+
+        Assert.Contains(a, all);
+        Assert.Contains(b, all);
+        Assert.Equal(2, all.Count);
+    }
+
+    [Fact]
+    public void KeysShouldReturnAllKeys()
+    {
+        _headers.Insert("A", "x");
+        _headers.Insert("B", "y");
+
+        Assert.Contains("A", _headers.Keys);
+        Assert.Contains("B", _headers.Keys);
+    }
+
+    [Fact]
+    public void ContainsKeyShouldWork()
+    {
+        _headers.Insert("Check", "yes");
+
+        Assert.True(_headers.ContainsKey("Check"));
+        Assert.False(_headers.ContainsKey("Missing"));
+    }
+
+    // Single-value lookups of an unknown key yield null; multi-value lookups yield an
+    // empty sequence.
+    [Fact]
+    public void GetNonExistentKeyShouldReturnNull()
+    {
+        Assert.Null(_headers.Get("MissingKey"));
+        Assert.Null(_headers.GetBytes("MissingKey"));
+        Assert.Empty(_headers.GetAll("MissingKey"));
+        Assert.Empty(_headers.GetAllBytes("MissingKey"));
+    }
+
+    [Fact]
+    public void ContainsKeyShouldBeFalseForMissingKey()
+    {
+        Assert.False(_headers.ContainsKey("DefinitelyMissing"));
+    }
+
+    [Fact]
+    public void KeysShouldBeEmptyWhenNoHeaders()
+    {
+        Assert.Empty(_headers.Keys);
+    }
+
+    [Fact]
+    public void IndexerShouldReturnNullForMissingKey()
+    {
+        string value = _headers["nonexistent"];
+
+        Assert.Null(value);
+    }
+
+    [Fact]
+    public void InsertEmptyStringsShouldStillStore()
+    {
+        _headers.Insert("Empty", "");
+
+        Assert.Equal("", _headers.Get("Empty"));
+        Assert.Single(_headers.GetAll("Empty"));
+    }
+
+    [Fact]
+    public void InsertEmptyByteArrayShouldStillStore()
+    {
+        var empty = System.Array.Empty<byte>();
+
+        _headers.Insert("BinaryEmpty", empty);
+
+        Assert.Equal(empty, _headers.GetBytes("BinaryEmpty"));
+        Assert.Single(_headers.GetAllBytes("BinaryEmpty"));
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs
new file mode 100644
index 0000000..12558f2
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientCookieMiddlewareTests.cs
@@ -0,0 +1,137 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Linq;
+using System.Net;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Middleware;
+using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+using Xunit;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests;
+
+/// <summary>
+/// Tests for what <see cref="ClientCookieMiddleware"/> writes to the "Cookie" header in
+/// OnBeforeSendingHeaders, given the cookies held by its factory.
+/// </summary>
+public class ClientCookieMiddlewareTests
+{
+    private readonly ClientCookieMiddlewareMock _middlewareMock = new();
+
+    // Builds the middleware under test on top of the given factory, with a no-op logger.
+    private static ClientCookieMiddleware NewMiddleware(ClientCookieMiddlewareFactory factory)
+    {
+        return new ClientCookieMiddleware(
+            factory,
+            new ClientCookieMiddlewareMock.TestLogger<ClientCookieMiddleware>());
+    }
+
+    // Runs OnBeforeSendingHeaders against a fresh header collection and hands it back.
+    private static InMemoryCallHeaders SendHeaders(ClientCookieMiddleware middleware)
+    {
+        var outgoing = new InMemoryCallHeaders();
+        middleware.OnBeforeSendingHeaders(outgoing);
+        return outgoing;
+    }
+
+    [Fact]
+    public void NoCookiesReturnsEmptyString()
+    {
+        var factory = _middlewareMock.CreateFactory();
+
+        var outgoing = SendHeaders(NewMiddleware(factory));
+
+        Assert.Empty(outgoing.GetAll("Cookie"));
+    }
+
+    [Fact]
+    public void OnlyExpiredCookiesRemovesAll()
+    {
+        var factory = _middlewareMock.CreateFactory();
+        var fiveMinutesAgo = DateTimeOffset.UtcNow.AddMinutes(-5);
+        factory.Cookies["expired"] = _middlewareMock.CreateCookie("expired", "value", fiveMinutesAgo);
+
+        var outgoing = SendHeaders(NewMiddleware(factory));
+
+        Assert.Empty(outgoing.GetAll("Cookie"));
+        Assert.Empty(factory.Cookies);
+    }
+
+    [Fact]
+    public void OnlyValidCookiesReturnsCookieHeader()
+    {
+        var factory = _middlewareMock.CreateFactory();
+        var inTenMinutes = DateTimeOffset.UtcNow.AddMinutes(10);
+        factory.Cookies["valid"] = _middlewareMock.CreateCookie("valid", "abc", inTenMinutes);
+
+        var outgoing = SendHeaders(NewMiddleware(factory));
+
+        var cookieHeader = outgoing.GetAll("Cookie").FirstOrDefault();
+        Assert.NotNull(cookieHeader);
+        Assert.Contains("valid=abc", cookieHeader);
+    }
+
+    [Fact]
+    public void MixedCookiesRemovesExpiredOnly()
+    {
+        var factory = _middlewareMock.CreateFactory();
+        var tenMinutesAgo = DateTimeOffset.UtcNow.AddMinutes(-10);
+        factory.Cookies["expired"] = _middlewareMock.CreateCookie("expired", "x", tenMinutesAgo);
+        var inTenMinutes = DateTimeOffset.UtcNow.AddMinutes(10);
+        factory.Cookies["valid"] = _middlewareMock.CreateCookie("valid", "y", inTenMinutes);
+
+        var outgoing = SendHeaders(NewMiddleware(factory));
+
+        var cookieHeader = outgoing.GetAll("Cookie").FirstOrDefault();
+        Assert.NotNull(cookieHeader);
+        Assert.Contains("valid=y", cookieHeader);
+        Assert.DoesNotContain("expired=x", cookieHeader);
+        Assert.Single(factory.Cookies);
+    }
+
+    [Fact]
+    public void DuplicateCookieKeysLastValidRemains()
+    {
+        var factory = _middlewareMock.CreateFactory();
+        // The second assignment overwrites the first entry stored under the same key.
+        var fiveMinutesAgo = DateTimeOffset.UtcNow.AddMinutes(-5);
+        factory.Cookies["token"] = _middlewareMock.CreateCookie("token", "old", fiveMinutesAgo);
+        var inFiveMinutes = DateTimeOffset.UtcNow.AddMinutes(5);
+        factory.Cookies["token"] = _middlewareMock.CreateCookie("token", "new", inFiveMinutes);
+
+        var outgoing = SendHeaders(NewMiddleware(factory));
+
+        var cookieHeader = outgoing.GetAll("Cookie").FirstOrDefault();
+        Assert.NotNull(cookieHeader);
+        Assert.Contains("token=new", cookieHeader);
+    }
+
+    [Fact]
+    public void FalsePositiveValidDateButMarkedExpired()
+    {
+        var factory = _middlewareMock.CreateFactory();
+        // Expiry date lies in the future, but the cookie is explicitly flagged as expired.
+        var inTenMinutes = DateTimeOffset.UtcNow.AddMinutes(10);
+        factory.Cookies["wrong"] = _middlewareMock.CreateCookie("wrong", "v", inTenMinutes, expiredOverride: true);
+
+        var outgoing = SendHeaders(NewMiddleware(factory));
+
+        Assert.Empty(outgoing.GetAll("Cookie"));
+    }
+
+    [Fact]
+    public async Task ConcurrentInsertRemoveDoesNotCorrupt()
+    {
+        var factory = _middlewareMock.CreateFactory();
+        var middleware = NewMiddleware(factory);
+
+        foreach (var i in Enumerable.Range(0, 100))
+        {
+            var expiry = DateTimeOffset.UtcNow.AddMinutes(5);
+            factory.Cookies[$"cookie{i}"] = _middlewareMock.CreateCookie($"cookie{i}", $"{i}", expiry);
+        }
+
+        // Each worker sends headers once and then tries to remove every cookie it can see.
+        void SendThenClear()
+        {
+            SendHeaders(middleware);
+            foreach (var key in factory.Cookies.Keys)
+            {
+                factory.Cookies.TryRemove(key, out Cookie _);
+            }
+        }
+
+        var workers = Enumerable.Range(0, 20).Select(_ => Task.Run(() => SendThenClear()));
+
+        // The meaningful check is that awaiting the workers does not throw; the count
+        // assertion below always holds.
+        await Task.WhenAll(workers);
+        Assert.True(factory.Cookies.Count >= 0);
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs
new file mode 100644
index 0000000..29777a4
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/ClientInterceptorAdapterTests.cs
@@ -0,0 +1,99 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading.Tasks;
+using Apache.Arrow.Flight.Client;
+using Apache.Arrow.Flight.Middleware.Interceptors;
+using Apache.Arrow.Flight.Sql.Tests.Stubs;
+using Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+using Grpc.Core;
+using Grpc.Core.Interceptors;
+using Xunit;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests;
+
+/// <summary>
+/// Drives a <see cref="FlightClient"/> whose channel is intercepted by a
+/// <see cref="ClientInterceptorAdapter"/> wrapping a <see cref="CapturingMiddlewareFactory"/>,
+/// and checks that the middleware hooks ran and which headers the middleware captured.
+/// </summary>
+public class ClientInterceptorAdapterTests
+{
+    private readonly TestWebFactory _testWebFactory;
+    private readonly FlightClient _client;
+    private readonly CapturingMiddlewareFactory _middlewareFactory;
+
+    public ClientInterceptorAdapterTests()
+    {
+        _testWebFactory = new TestWebFactory(new InMemoryFlightStore());
+
+        _middlewareFactory = new CapturingMiddlewareFactory();
+        var interceptor = new ClientInterceptorAdapter([_middlewareFactory]);
+
+        _client = new FlightClient(_testWebFactory.GetChannel().Intercept(interceptor));
+    }
+
+    [Fact]
+    public async Task MiddlewareFlowIsCalledCorrectly()
+    {
+        // Arrange
+        var descriptor = FlightDescriptor.CreatePathDescriptor("test");
+
+        // Act
+        var info = await _client.GetInfo(descriptor);
+        var middleware = _middlewareFactory.Instance;
+
+        // Assert
+        Assert.NotNull(info);
+        Assert.True(middleware.BeforeHeadersCalled, "BeforeHeaders not called");
+        Assert.True(middleware.HeadersReceivedCalled, "HeadersReceived not called");
+        Assert.True(middleware.CallCompletedCalled, "CallCompleted not called");
+    }
+
+    [Fact]
+    public async Task CookieAndHeaderValuesArePersistedThroughMiddleware()
+    {
+        // Arrange
+        var descriptor = FlightDescriptor.CreatePathDescriptor("test");
+
+        // Act
+        try
+        {
+            await _client.GetInfo(descriptor);
+        }
+        catch (RpcException)
+        {
+            // The "test" descriptor matches the flight InMemoryFlightStore registers, so a
+            // failure is not expected here. An RpcException is tolerated anyway because this
+            // test only checks what the middleware observed.
+        }
+
+        // Assert Middleware captured the headers and cookies correctly
+        var middleware = _middlewareFactory.Instance;
+
+        Assert.True(middleware.BeforeHeadersCalled, "OnBeforeSendingHeaders not called");
+        Assert.True(middleware.HeadersReceivedCalled, "OnHeadersReceived not called");
+        Assert.True(middleware.CallCompletedCalled, "OnCallCompleted not called");
+
+        // Validate Cookies captured correctly (the cookie header is the one
+        // CapturingMiddleware inserts in OnBeforeSendingHeaders)
+        Assert.True(middleware.CapturedHeaders.ContainsKey("cookie"));
+        var cookies = ParseCookies(middleware.CapturedHeaders["cookie"]);
+
+        Assert.Equal("abc123", cookies["sessionId"]);
+        Assert.Equal("xyz789", cookies["token"]);
+    }
+
+    /// <summary>
+    /// Splits a "name=value; name=value" cookie header into a name-to-value map.
+    /// </summary>
+    /// <param name="cookieHeader">Header text; segments are separated by ';'.</param>
+    /// <returns>Trimmed cookie names mapped to their trimmed values.</returns>
+    /// <exception cref="System.ArgumentException">The same cookie name appears twice.</exception>
+    private static IDictionary<string, string> ParseCookies(string cookieHeader)
+    {
+        var cookies = new Dictionary<string, string>();
+        foreach (var pair in cookieHeader.Split(';'))
+        {
+            // Split on the first '=' only, so a value that itself contains '=' (for example
+            // base64 padding) is kept whole instead of being truncated.
+            int separator = pair.IndexOf('=');
+            if (separator < 0)
+            {
+                // A segment without '=' (a bare attribute or an empty piece) carries no
+                // name/value pair; skip it rather than fail with an index error.
+                continue;
+            }
+
+            cookies.Add(pair.Substring(0, separator).Trim(), pair.Substring(separator + 1).Trim());
+        }
+
+        return cookies;
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs
new file mode 100644
index 0000000..47c7cac
--- /dev/null
+++ b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/CookieExtensionsTests.cs
@@ -0,0 +1,127 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using Apache.Arrow.Flight.Middleware.Extensions;
+using Xunit;
+using static System.Linq.Enumerable;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests;
+
+/// <summary>
+/// Tests for the ParseHeader string extension: well-formed cookies, cookies with
+/// attributes, and malformed, empty, or null input.
+/// </summary>
+public class CookieExtensionsTests
+{
+    [Fact]
+    public void ParseHeaderShouldParseSimpleCookie()
+    {
+        var parsed = "sessionId=abc123".ParseHeader().ToList();
+
+        var cookie = Assert.Single(parsed);
+        Assert.Equal("sessionId", cookie.Name);
+        Assert.Equal("abc123", cookie.Value);
+        Assert.False(cookie.Expired);
+    }
+
+    [Fact]
+    public void ParseHeaderShouldParseCookieWithExpires()
+    {
+        var expectedExpiry = DateTimeOffset.UtcNow.AddDays(7);
+        var input = $"userId=789; Expires={expectedExpiry:R}";
+
+        var parsed = input.ParseHeader().ToList();
+
+        var cookie = Assert.Single(parsed);
+        Assert.Equal("userId", cookie.Name);
+        Assert.Equal("789", cookie.Value);
+        // The "R" format carries whole seconds only, so compare with a tolerance.
+        var drift = cookie.Expires - expectedExpiry.UtcDateTime;
+        Assert.True(Math.Abs(drift.TotalSeconds) < 5);
+    }
+
+    [Fact]
+    public void ParseHeaderShouldReturnEmptyWhenMalformed()
+    {
+        var parsed = "this_is_wrong".ParseHeader().ToList();
+
+        Assert.Empty(parsed);
+    }
+
+    [Fact]
+    public void ParseHeaderShouldReturnEmptyWhenEmptyString()
+    {
+        var parsed = string.Empty.ParseHeader().ToList();
+
+        Assert.Empty(parsed);
+    }
+
+    [Fact]
+    public void ParseHeaderShouldReturnEmptyWhenNullString()
+    {
+        string input = null;
+
+        var parsed = input.ParseHeader().ToList();
+
+        Assert.Empty(parsed);
+    }
+
+    [Fact]
+    public void ParseHeaderShouldParseCookieIgnoringAttributes()
+    {
+        var parsed = "token=xyz; Path=/; HttpOnly".ParseHeader().ToList();
+
+        var cookie = Assert.Single(parsed);
+        Assert.Equal("token", cookie.Name);
+        Assert.Equal("xyz", cookie.Value);
+    }
+
+    [Fact]
+    public void ParseHeaderShouldIgnoreInvalidExpires()
+    {
+        var parsed = "name=value; Expires=invalid-date".ParseHeader().ToList();
+
+        var cookie = Assert.Single(parsed);
+        Assert.Equal("name", cookie.Name);
+        Assert.Equal("value", cookie.Value);
+        Assert.False(cookie.Expired);
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs
new file mode 100644
index 0000000..a58f47a
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddleware.cs
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Collections.Generic;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+
+/// <summary>
+/// Test middleware that records which hooks were invoked and copies the string headers it
+/// is shown into <see cref="CapturedHeaders"/>. It also inserts a test header and a cookie
+/// header into the outgoing headers.
+/// </summary>
+public class CapturingMiddleware : IFlightClientMiddleware
+{
+    public Dictionary<string, string> CapturedHeaders { get; } = new();
+
+    public bool BeforeHeadersCalled { get; private set; }
+
+    public bool HeadersReceivedCalled { get; private set; }
+
+    public bool CallCompletedCalled { get; private set; }
+
+    public void OnBeforeSendingHeaders(ICallHeaders outgoingHeaders)
+    {
+        BeforeHeadersCalled = true;
+        outgoingHeaders.Insert("x-test-header", "test-value");
+        outgoingHeaders.Insert("cookie", "sessionId=abc123; token=xyz789");
+        Record(outgoingHeaders);
+    }
+
+    public void OnHeadersReceived(ICallHeaders incomingHeaders)
+    {
+        HeadersReceivedCalled = true;
+        Record(incomingHeaders);
+    }
+
+    public void OnCallCompleted(Status status, Metadata trailers) => CallCompletedCalled = true;
+
+    // Copies every key whose string lookup is non-null; a later capture of the same key
+    // overwrites the earlier one.
+    private void Record(ICallHeaders source)
+    {
+        foreach (var name in source.Keys)
+        {
+            if (source[name] is not string text)
+            {
+                continue;
+            }
+
+            CapturedHeaders[name] = text;
+        }
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs
new file mode 100644
index 0000000..5618cdf
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/CapturingMiddlewareFactory.cs
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Dynamic;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using CallInfo = Apache.Arrow.Flight.Middleware.CallInfo;
+
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+
+/// <summary>
+/// Test factory that hands the same <see cref="CapturingMiddleware"/> to every call, so a
+/// test can inspect it afterwards through <see cref="Instance"/>.
+/// </summary>
+public class CapturingMiddlewareFactory : IFlightClientMiddlewareFactory
+{
+    public CapturingMiddlewareFactory()
+    {
+        Instance = new CapturingMiddleware();
+    }
+
+    public CapturingMiddleware Instance { get; }
+
+    public IFlightClientMiddleware OnCallStarted(CallInfo callInfo)
+    {
+        return Instance;
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs
new file mode 100644
index 0000000..540bf9e
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/ClientCookieMiddlewareMock.cs
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Net;
+using Apache.Arrow.Flight.Middleware;
+using Microsoft.Extensions.Logging;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+
+/// <summary>
+/// Helpers for the cookie middleware tests: builds <see cref="Cookie"/> instances and a
+/// <see cref="ClientCookieMiddlewareFactory"/> wired to loggers that discard everything.
+/// </summary>
+internal class ClientCookieMiddlewareMock
+{
+    /// <summary>
+    /// Creates a cookie with the given name, value and optional expiry.
+    /// </summary>
+    /// <param name="name">Cookie name.</param>
+    /// <param name="value">Cookie value.</param>
+    /// <param name="expires">
+    /// Expiry instant, stored as UTC. When null the cookie keeps <see cref="DateTime.MinValue"/>,
+    /// the default of <see cref="Cookie.Expires"/>.
+    /// </param>
+    /// <param name="expiredOverride">
+    /// When set, assigned to <see cref="Cookie.Expired"/> regardless of <paramref name="expires"/>;
+    /// otherwise the flag is true only when an expiry was given and lies in the past.
+    /// </param>
+    /// <returns>The configured cookie.</returns>
+    public Cookie CreateCookie(string name, string value, DateTimeOffset? expires = null, bool? expiredOverride = null)
+    {
+        return new Cookie
+        {
+            Name = name,
+            Value = value,
+            // The parameter defaults to null, so it must not be dereferenced unconditionally
+            // (the previous `expires!.Value` threw InvalidOperationException for the default).
+            Expires = expires?.UtcDateTime ?? DateTime.MinValue,
+            Expired = expiredOverride ?? (expires.HasValue && expires.Value < DateTimeOffset.UtcNow)
+        };
+    }
+
+    /// <summary>Creates a cookie middleware factory backed by <see cref="TestLoggerFactory"/>.</summary>
+    public ClientCookieMiddlewareFactory CreateFactory()
+    {
+        return new ClientCookieMiddlewareFactory(new TestLoggerFactory());
+    }
+
+    /// <summary>Logger that is never enabled and ignores every log call.</summary>
+    public class TestLogger<T> : ILogger<T>
+    {
+        public IDisposable BeginScope<TState>(TState state) => null;
+
+        public bool IsEnabled(LogLevel logLevel) => false;
+
+        public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception exception,
+            Func<TState, Exception, string> formatter)
+        {
+        }
+    }
+
+    /// <summary>Logger factory that ignores providers and always returns a <see cref="TestLogger{T}"/>.</summary>
+    internal class TestLoggerFactory : ILoggerFactory
+    {
+        public void AddProvider(ILoggerProvider provider)
+        {
+        }
+
+        public ILogger CreateLogger(string categoryName) => new TestLogger<object>();
+
+        public void Dispose()
+        {
+        }
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs
new file mode 100644
index 0000000..11c04e3
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryCallHeaders.cs
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Apache.Arrow.Flight.Middleware;
+using Apache.Arrow.Flight.Middleware.Interfaces;
+using Grpc.Core;
+
+namespace Apache.Arrow.Flight.Tests.MiddlewareTests.Stubs;
+
+/// <summary>
+/// Test implementation of <see cref="ICallHeaders"/>. String values are kept in a
+/// <see cref="CallHeaders"/> over an empty <see cref="Metadata"/>; binary values are kept in
+/// a separate dictionary of lists. Keys are lower-cased before every lookup or insert.
+/// </summary>
+public class InMemoryCallHeaders : ICallHeaders
+{
+    private readonly CallHeaders _stringHeaders = new(new Metadata());
+
+    private readonly Dictionary<string, List<byte[]>> _byteHeaders = new(StringComparer.OrdinalIgnoreCase);
+
+    private static string NormalizeKey(string key)
+    {
+        return key.ToLowerInvariant();
+    }
+
+    public string this[string key] => Get(key);
+
+    /// <summary>Returns the string value stored under the key, or null when there is none.</summary>
+    public string Get(string key)
+    {
+        var name = NormalizeKey(key);
+        if (!_stringHeaders.ContainsKey(name))
+        {
+            return null;
+        }
+
+        return _stringHeaders[name];
+    }
+
+    /// <summary>Returns the most recently inserted binary value for the key, or null when there is none.</summary>
+    public byte[] GetBytes(string key)
+    {
+        if (_byteHeaders.TryGetValue(NormalizeKey(key), out var stored))
+        {
+            return stored.LastOrDefault();
+        }
+
+        return null;
+    }
+
+    /// <summary>Lazily enumerates every string value stored under the key.</summary>
+    public IEnumerable<string> GetAll(string key)
+    {
+        var name = NormalizeKey(key);
+        return from header in _stringHeaders
+               where string.Equals(header.Key, name, StringComparison.OrdinalIgnoreCase)
+               select header.Value;
+    }
+
+    /// <summary>Returns every binary value stored under the key, or an empty sequence.</summary>
+    public IEnumerable<byte[]> GetAllBytes(string key)
+    {
+        if (_byteHeaders.TryGetValue(NormalizeKey(key), out var stored))
+        {
+            return stored;
+        }
+
+        return Enumerable.Empty<byte[]>();
+    }
+
+    public void Insert(string key, string value)
+    {
+        _stringHeaders.Add(NormalizeKey(key), value);
+    }
+
+    public void Insert(string key, byte[] value)
+    {
+        var name = NormalizeKey(key);
+        if (!_byteHeaders.TryGetValue(name, out var bucket))
+        {
+            bucket = new List<byte[]>();
+            _byteHeaders[name] = bucket;
+        }
+
+        bucket.Add(value);
+    }
+
+    /// <summary>Builds a new case-insensitive set of all string and binary keys on each access.</summary>
+    public ISet<string> Keys
+    {
+        get
+        {
+            var stringKeys = _stringHeaders.Select(header => header.Key.ToLowerInvariant());
+            return new HashSet<string>(stringKeys.Concat(_byteHeaders.Keys), StringComparer.OrdinalIgnoreCase);
+        }
+    }
+
+    public bool ContainsKey(string key)
+    {
+        var name = NormalizeKey(key);
+        return _stringHeaders.ContainsKey(name) || _byteHeaders.ContainsKey(name);
+    }
+}
diff --git
a/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs
new file mode 100644
index 0000000..f66b39c
--- /dev/null
+++
b/test/Apache.Arrow.Flight.Tests/MiddlewareTests/Stubs/InMemoryFlightStore.cs
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using Apache.Arrow.Flight.TestWeb;
+using Apache.Arrow.Types;
+
+namespace Apache.Arrow.Flight.Sql.Tests.Stubs;
+
+/// <summary>
+/// <see cref="FlightStore"/> pre-populated with a single flight at path descriptor "test":
+/// a two-column schema (id: Int32, name: String) holding one row.
+/// </summary>
+public class InMemoryFlightStore : FlightStore
+{
+    public InMemoryFlightStore()
+    {
+        // Registered up front so a GetFlightInfo request for "test" can be resolved.
+        var testDescriptor = FlightDescriptor.CreatePathDescriptor("test");
+
+        var schemaBuilder = new Schema.Builder();
+        schemaBuilder = schemaBuilder.Field(f => f.Name("id").DataType(Int32Type.Default));
+        schemaBuilder = schemaBuilder.Field(f => f.Name("name").DataType(StringType.Default));
+        var testSchema = schemaBuilder.Build();
+
+        var idColumn = new Int32Array.Builder().Append(1).Build();
+        var nameColumn = new StringArray.Builder().Append("John Doe").Build();
+        var singleRow = new RecordBatch(testSchema, new Array[] { idColumn, nameColumn }, 1);
+
+        var endpoint = new FlightLocation("grpc+tcp://localhost:50051");
+
+        var holder = new FlightHolder(testDescriptor, testSchema, endpoint.Uri);
+        holder.AddBatch(new RecordBatchWithMetadata(singleRow));
+        Flights.Add(testDescriptor, holder);
+    }
+
+    public override string ToString() => $"InMemoryFlightStore(Flights={Flights.Count})";
+}