This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new c1507ad20a generic channel support for FlightClient (#9933)
c1507ad20a is described below

commit c1507ad20a3dad44353bd9fb4c489785298d10d8
Author: Rostislav Rumenov <[email protected]>
AuthorDate: Thu May 7 20:10:50 2026 +0200

    generic channel support for FlightClient (#9933)
    
    Allow FlightClient to be parameterized over the underlying channel type,
    so users can wrap a tonic channel with custom interceptors or services.
    Motivation: annotating outbound Flight requests with metadata (e.g.
    injecting OpenTelemetry trace context into headers) currently requires
    forking or wrapping at a higher level. Making the channel generic lets
    callers compose tower layers/interceptors idiomatically and propagate
    distributed tracing context without bespoke plumbing.
    
    ---------
    
    Co-authored-by: Rostislav Rumenov <[email protected]>
---
 arrow-flight/src/client.rs | 291 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 280 insertions(+), 11 deletions(-)

diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index dac086271c..b2059a81d0 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -31,6 +31,7 @@ use futures::{
     stream::{self, BoxStream},
 };
 use prost::Message;
+use tonic::codegen::{Body, StdError};
 use tonic::{metadata::MetadataMap, transport::Channel};
 
 use crate::error::{FlightError, Result};
@@ -67,22 +68,28 @@ use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};
 /// # }
 /// ```
 #[derive(Debug)]
-pub struct FlightClient {
+pub struct FlightClient<T = Channel> {
     /// Optional grpc header metadata to include with each request
     metadata: MetadataMap,
 
     /// The inner client
-    inner: FlightServiceClient<Channel>,
+    inner: FlightServiceClient<T>,
 }
 
-impl FlightClient {
-    /// Creates a client client with the provided [`Channel`]
-    pub fn new(channel: Channel) -> Self {
-        Self::new_from_inner(FlightServiceClient::new(channel))
+impl<T> FlightClient<T>
+where
+    T: tonic::client::GrpcService<tonic::body::Body>,
+    T::Error: Into<StdError>,
+    T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
+    <T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
+{
+    /// Creates a client with the provided transport
+    pub fn new(inner: T) -> Self {
+        Self::new_from_inner(FlightServiceClient::new(inner))
     }
 
     /// Creates a new higher level client with the provided lower level client
-    pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
+    pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
         Self {
             metadata: MetadataMap::new(),
             inner,
@@ -120,19 +127,19 @@ impl FlightClient {
 
     /// Return a reference to the underlying tonic
     /// [`FlightServiceClient`]
-    pub fn inner(&self) -> &FlightServiceClient<Channel> {
+    pub fn inner(&self) -> &FlightServiceClient<T> {
         &self.inner
     }
 
     /// Return a mutable reference to the underlying tonic
     /// [`FlightServiceClient`]
-    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
+    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
         &mut self.inner
     }
 
     /// Consume this client and return the underlying tonic
     /// [`FlightServiceClient`]
-    pub fn into_inner(self) -> FlightServiceClient<Channel> {
+    pub fn into_inner(self) -> FlightServiceClient<T> {
         self.inner
     }
 
@@ -664,10 +671,272 @@ impl FlightClient {
     }
 
     /// return a Request, adding any configured metadata
-    fn make_request<T>(&self, t: T) -> tonic::Request<T> {
+    fn make_request<R>(&self, t: R) -> tonic::Request<R> {
         // Pass along metadata
         let mut request = tonic::Request::new(t);
         *request.metadata_mut() = self.metadata.clone();
         request
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::FlightClient;
+    use crate::encode::FlightDataEncoderBuilder;
+    use crate::flight_service_server::{FlightService, FlightServiceServer};
+    use crate::{
+        Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+        HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
+    };
+    use arrow_array::{RecordBatch, UInt64Array};
+    use bytes::Bytes;
+    use futures::{StreamExt, TryStreamExt, stream::BoxStream};
+    use std::net::SocketAddr;
+    use std::sync::{Arc, Mutex};
+    use std::time::Duration;
+    use tokio::net::TcpListener;
+    use tokio::task::JoinHandle;
+    use tonic::metadata::MetadataMap;
+    use tonic::service::interceptor::InterceptedService;
+    use tonic::transport::Channel;
+    use tonic::{Request, Response, Status, Streaming};
+    use uuid::Uuid;
+
+    /// Minimal `FlightService` that records request metadata and serves a
+    /// configured `do_get` response. Other RPCs return `Unimplemented`.
+    #[derive(Debug, Clone, Default)]
+    struct InterceptorTestServer {
+        state: Arc<Mutex<InterceptorTestState>>,
+    }
+
+    #[derive(Debug, Default)]
+    struct InterceptorTestState {
+        do_get_request: Option<Ticket>,
+        do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
+        last_request_metadata: Option<MetadataMap>,
+    }
+
+    impl InterceptorTestServer {
+        fn save_metadata<T>(&self, request: &Request<T>) {
+            self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone());
+        }
+
+        fn set_do_get_response(&self, response: Vec<Result<RecordBatch, Status>>) {
+            self.state.lock().unwrap().do_get_response = Some(response);
+        }
+
+        fn take_do_get_request(&self) -> Option<Ticket> {
+            self.state.lock().unwrap().do_get_request.take()
+        }
+
+        fn take_last_request_metadata(&self) -> Option<MetadataMap> {
+            self.state.lock().unwrap().last_request_metadata.take()
+        }
+    }
+
+    #[tonic::async_trait]
+    impl FlightService for InterceptorTestServer {
+        type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
+        type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
+        type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
+        type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
+        type DoActionStream = BoxStream<'static, Result<crate::Result, Status>>;
+        type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
+        type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
+
+        async fn do_get(
+            &self,
+            request: Request<Ticket>,
+        ) -> Result<Response<Self::DoGetStream>, Status> {
+            self.save_metadata(&request);
+            let mut state = self.state.lock().unwrap();
+            state.do_get_request = Some(request.into_inner());
+
+            let batches = state
+                .do_get_response
+                .take()
+                .ok_or_else(|| Status::internal("no do_get response configured"))?;
+            let batch_stream = futures::stream::iter(batches).map_err(Into::into);
+            let stream = FlightDataEncoderBuilder::new()
+                .build(batch_stream)
+                .map_err(Into::into);
+            Ok(Response::new(stream.boxed()))
+        }
+
+        async fn handshake(
+            &self,
+            _: Request<Streaming<HandshakeRequest>>,
+        ) -> Result<Response<Self::HandshakeStream>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn list_flights(
+            &self,
+            _: Request<Criteria>,
+        ) -> Result<Response<Self::ListFlightsStream>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn get_flight_info(
+            &self,
+            _: Request<FlightDescriptor>,
+        ) -> Result<Response<FlightInfo>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn poll_flight_info(
+            &self,
+            _: Request<FlightDescriptor>,
+        ) -> Result<Response<PollInfo>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn get_schema(
+            &self,
+            _: Request<FlightDescriptor>,
+        ) -> Result<Response<SchemaResult>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn do_put(
+            &self,
+            _: Request<Streaming<FlightData>>,
+        ) -> Result<Response<Self::DoPutStream>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn do_action(
+            &self,
+            _: Request<Action>,
+        ) -> Result<Response<Self::DoActionStream>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn list_actions(
+            &self,
+            _: Request<Empty>,
+        ) -> Result<Response<Self::ListActionsStream>, Status> {
+            Err(Status::unimplemented(""))
+        }
+        async fn do_exchange(
+            &self,
+            _: Request<Streaming<FlightData>>,
+        ) -> Result<Response<Self::DoExchangeStream>, Status> {
+            Err(Status::unimplemented(""))
+        }
+    }
+
+    /// Spawns the test server on a background task and exposes a connected channel.
+    struct InterceptorTestFixture {
+        shutdown: Option<tokio::sync::oneshot::Sender<()>>,
+        addr: SocketAddr,
+        handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
+    }
+
+    impl InterceptorTestFixture {
+        async fn new(server: FlightServiceServer<InterceptorTestServer>) -> Self {
+            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+            let addr = listener.local_addr().unwrap();
+            let (tx, rx) = tokio::sync::oneshot::channel();
+            let shutdown_future = async move {
+                rx.await.ok();
+            };
+            let serve = tonic::transport::Server::builder()
+                .timeout(Duration::from_secs(30))
+                .add_service(server)
+                .serve_with_incoming_shutdown(
+                    tokio_stream::wrappers::TcpListenerStream::new(listener),
+                    shutdown_future,
+                );
+            let handle = tokio::task::spawn(serve);
+            Self {
+                shutdown: Some(tx),
+                addr,
+                handle: Some(handle),
+            }
+        }
+
+        async fn channel(&self) -> Channel {
+            let url = format!("http://{}", self.addr);
+            tonic::transport::Endpoint::from_shared(url)
+                .expect("valid endpoint")
+                .timeout(Duration::from_secs(30))
+                .connect()
+                .await
+                .expect("error connecting to server")
+        }
+
+        async fn shutdown_and_wait(mut self) {
+            if let Some(tx) = self.shutdown.take() {
+                tx.send(()).expect("server quit early");
+            }
+            if let Some(handle) = self.handle.take() {
+                handle
+                    .await
+                    .expect("task join error (panic?)")
+                    .expect("server error at shutdown");
+            }
+        }
+    }
+
+    /// Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`]
+    /// that injects a custom header is passed to [`FlightClient`], and the server
+    /// observes the header on the request.
+    #[tokio::test]
+    async fn test_flight_client_with_intercepted_channel_passes_custom_header() {
+        let test_server = InterceptorTestServer::default();
+        let fixture =
+            InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await;
+
+        let channel = fixture.channel().await;
+
+        let header_name = "x-random-header";
+        let header_value = format!("random-{}", Uuid::new_v4());
+        let header_value_for_interceptor = header_value.clone();
+
+        let interceptor = move |mut req: Request<()>| -> Result<Request<()>, Status> {
+            req.metadata_mut().insert(
+                header_name,
+                header_value_for_interceptor
+                    .parse()
+                    .expect("valid metadata value"),
+            );
+            Ok(req)
+        };
+
+        let intercepted = InterceptedService::new(channel, interceptor);
+        let mut client = FlightClient::new(intercepted);
+
+        let ticket = Ticket {
+            ticket: Bytes::from("dummy-ticket"),
+        };
+
+        let batch = RecordBatch::try_from_iter(vec![(
+            "col",
+            Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
+        )])
+        .unwrap();
+
+        test_server.set_do_get_response(vec![Ok(batch.clone())]);
+
+        let response_stream = client
+            .do_get(ticket.clone())
+            .await
+            .expect("error making do_get request");
+
+        let response: Vec<RecordBatch> = response_stream
+            .try_collect()
+            .await
+            .expect("error streaming data");
+
+        assert_eq!(response, vec![batch]);
+        assert_eq!(test_server.take_do_get_request(), Some(ticket));
+
+        let metadata = test_server
+            .take_last_request_metadata()
+            .expect("server received headers")
+            .into_headers();
+
+        let received = metadata
+            .get(header_name)
+            .expect("interceptor header missing on server")
+            .to_str()
+            .expect("ascii header value");
+        assert_eq!(received, header_value);
+
+        fixture.shutdown_and_wait().await;
+    }
+}

Reply via email to