This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 969dd290ab GH-33953: [Java] Pass custom headers on every request (#33967)
969dd290ab is described below

commit 969dd290ab93c8629b643134f150e71c1c05c2ff
Author: Diego Fernández Giraldo <[email protected]>
AuthorDate: Mon Feb 27 06:41:23 2023 -0700

    GH-33953: [Java] Pass custom headers on every request (#33967)
    
    
    
    ### Rationale for this change
    
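    Some Flight requests don't send custom headers. `startPut`, `getStream`, and
    `doExchange` created their calls directly on the client's pre-intercepted
    channel and copied only the `CallOptions` from the option-wrapped stub, so
    interceptors attached by a `CallOption` (such as the one installed by
    `HeaderCallOption`) never ran for those calls. This PR fixes that.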
    
    ### What changes are included in this PR?
    
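    Ensure custom headers are sent on every request: `startPut`, `getStream`,
    and `doExchange` now create their calls through a new `asyncStubNewCall`
    helper that uses the option-wrapped stub's own channel and call options.
    `HeaderCallOption` now attaches its headers with
    `stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(...))`
    instead of the deprecated `MetadataUtils.attachHeaders`.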
    
    ### Are these changes tested?
    
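    Yes. The new `CustomHeaderTest` checks that the custom headers reach the
    server middleware for Handshake, GetSchema, GetFlightInfo, ListActions,
    ListFlights, DoAction, DoPut, DoGet, and DoExchange.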
    
    ### Are there any user-facing changes?
    
    Custom headers should now be attached to every call.
    
    * Closes: #33953
    * Closes: #33839
    
    Authored-by: Diego Fernandez <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../java/org/apache/arrow/flight/FlightClient.java |  25 ++-
 .../org/apache/arrow/flight/HeaderCallOption.java  |   2 +-
 .../arrow/flight/client/CustomHeaderTest.java      | 230 +++++++++++++++++++++
 3 files changed, 247 insertions(+), 10 deletions(-)

diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
index 1f50f50a29..8cf2bbaf25 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -79,7 +79,7 @@ public class FlightClient implements AutoCloseable {
   private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
   private final BufferAllocator allocator;
   private final ManagedChannel channel;
-  private final Channel interceptedChannel;
+
   private final FlightServiceBlockingStub blockingStub;
   private final FlightServiceStub asyncStub;
   private final ClientAuthInterceptor authInterceptor = new 
ClientAuthInterceptor();
@@ -101,7 +101,7 @@ public class FlightClient implements AutoCloseable {
     interceptors = new ClientInterceptor[]{authInterceptor, new 
ClientInterceptorAdapter(middleware)};
 
     // Create a channel with interceptors pre-applied for DoGet and DoPut
-    this.interceptedChannel = ClientInterceptors.intercept(channel, 
interceptors);
+    Channel interceptedChannel = ClientInterceptors.intercept(channel, 
interceptors);
 
     blockingStub = FlightServiceGrpc.newBlockingStub(interceptedChannel);
     asyncStub = FlightServiceGrpc.newStub(interceptedChannel);
@@ -255,13 +255,12 @@ public class FlightClient implements AutoCloseable {
                                        CallOption... options) {
     Preconditions.checkNotNull(descriptor, "descriptor must not be null");
     Preconditions.checkNotNull(metadataListener, "metadataListener must not be 
null");
-    final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, 
options).getCallOptions();
 
     try {
+      final ClientCall<ArrowMessage, Flight.PutResult> call = 
asyncStubNewCall(doPutDescriptor, options);
       final SetStreamObserver resultObserver = new 
SetStreamObserver(allocator, metadataListener);
       ClientCallStreamObserver<ArrowMessage> observer = 
(ClientCallStreamObserver<ArrowMessage>)
-          ClientCalls.asyncBidiStreamingCall(
-              interceptedChannel.newCall(doPutDescriptor, callOptions), 
resultObserver);
+          ClientCalls.asyncBidiStreamingCall(call, resultObserver);
       return new PutObserver(
           descriptor, observer, metadataListener::isCancelled, 
metadataListener::getResult);
     } catch (StatusRuntimeException sre) {
@@ -306,8 +305,7 @@ public class FlightClient implements AutoCloseable {
    * @param options RPC-layer hints for this call.
    */
   public FlightStream getStream(Ticket ticket, CallOption... options) {
-    final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, 
options).getCallOptions();
-    ClientCall<Flight.Ticket, ArrowMessage> call = 
interceptedChannel.newCall(doGetDescriptor, callOptions);
+    final ClientCall<Flight.Ticket, ArrowMessage> call = 
asyncStubNewCall(doGetDescriptor, options);
     FlightStream stream = new FlightStream(
         allocator,
         PENDING_REQUESTS,
@@ -353,10 +351,9 @@ public class FlightClient implements AutoCloseable {
    */
   public ExchangeReaderWriter doExchange(FlightDescriptor descriptor, 
CallOption... options) {
     Preconditions.checkNotNull(descriptor, "descriptor must not be null");
-    final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, 
options).getCallOptions();
 
     try {
-      final ClientCall<ArrowMessage, ArrowMessage> call = 
interceptedChannel.newCall(doExchangeDescriptor, callOptions);
+      final ClientCall<ArrowMessage, ArrowMessage> call = 
asyncStubNewCall(doExchangeDescriptor, options);
       final FlightStream stream = new FlightStream(allocator, 
PENDING_REQUESTS, call::cancel, call::request);
       final ClientCallStreamObserver<ArrowMessage> observer = 
(ClientCallStreamObserver<ArrowMessage>)
               ClientCalls.asyncBidiStreamingCall(call, stream.asObserver());
@@ -723,4 +720,14 @@ public class FlightClient implements AutoCloseable {
       return new FlightClient(allocator, builder.build(), middleware);
     }
   }
+
+  /**
+   * Helper method to create a call from the asyncStub, method descriptor, and list of calling options.
+   *
+   * <p>The {@code options} are first applied to {@code asyncStub} (which is already bound to the channel carrying
+   * the client's auth and middleware interceptors), and the call is then created on the wrapped stub's own channel
+   * with the wrapped stub's call options. Creating the call on that channel, rather than copying only the call
+   * options, ensures interceptors attached by a {@link CallOption} (e.g. custom headers) also run for this call.
+   *
+   * @param descriptor the gRPC method to call.
+   * @param options RPC-layer hints for this call.
+   * @param <RequestT> the request message type.
+   * @param <ResponseT> the response message type.
+   * @return a new call that has not been started yet.
+   */
+  private <RequestT, ResponseT> ClientCall<RequestT, ResponseT> asyncStubNewCall(
+          MethodDescriptor<RequestT, ResponseT> descriptor,
+          CallOption... options) {
+    FlightServiceStub wrappedStub = CallOptions.wrapStub(asyncStub, options);
+    return wrappedStub.getChannel().newCall(descriptor, wrappedStub.getCallOptions());
+  }
 }
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
index e2fad1a402..1a04ca3d08 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
@@ -47,6 +47,6 @@ public class HeaderCallOption implements CallOptions.GrpcCallOption {
 
  @Override
   public <T extends AbstractStub<T>> T wrapStub(T stub) {
-    return MetadataUtils.attachHeaders(stub, propertiesMetadata);
+    return // Headers are sent by an interceptor on the returned stub's channel; create calls on getChannel().
stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(propertiesMetadata));
   }
 }
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/CustomHeaderTest.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/CustomHeaderTest.java
new file mode 100644
index 0000000000..a320d949cd
--- /dev/null
+++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/CustomHeaderTest.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.client;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightCallHeaders;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightClient.ClientStreamListener;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightMethod;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.HeaderCallOption;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.RequestContext;
+import org.apache.arrow.flight.SyncPutListener;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableMap;
+
+
+/**
+ * Tests that custom headers supplied through a {@link HeaderCallOption} reach the server for every Flight RPC.
+ *
+ * <p>The server is backed by {@link NoOpFlightProducer}, so the calls themselves may fail; such failures are
+ * deliberately ignored because each test only asserts on the headers recorded by the server middleware.
+ */
+public class CustomHeaderTest {
+  private FlightServer server;
+  private FlightClient client;
+  private BufferAllocator allocator;
+  private TestCustomHeaderMiddleware.Factory headersMiddleware;
+  private HeaderCallOption headers;
+  private final Map<String, String> testHeaders = ImmutableMap.of(
+          "foo", "bar",
+          "bar", "foo",
+          "answer", "42"
+  );
+
+  @Before
+  public void setUp() throws Exception {
+    allocator = new RootAllocator(Integer.MAX_VALUE);
+    headersMiddleware = new TestCustomHeaderMiddleware.Factory();
+    FlightCallHeaders callHeaders = new FlightCallHeaders();
+    for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
+      callHeaders.insert(entry.getKey(), entry.getValue());
+    }
+    headers = new HeaderCallOption(callHeaders);
+    server = FlightTestUtil.getStartedServer(location ->
+            FlightServer.builder(allocator, location, new NoOpFlightProducer())
+                    .middleware(FlightServerMiddleware.Key.of("customHeader"), headersMiddleware)
+                    .build());
+    client = FlightClient.builder(allocator, server.getLocation()).build();
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    allocator.getChildAllocators().forEach(BufferAllocator::close);
+    AutoCloseables.close(allocator, server, client);
+  }
+
+  @Test
+  public void testHandshake() {
+    try {
+      client.handshake(headers);
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.HANDSHAKE);
+  }
+
+  @Test
+  public void testGetSchema() {
+    try {
+      client.getSchema(FlightDescriptor.command(new byte[0]), headers);
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.GET_SCHEMA);
+  }
+
+  @Test
+  public void testGetFlightInfo() {
+    try {
+      client.getInfo(FlightDescriptor.command(new byte[0]), headers);
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.GET_FLIGHT_INFO);
+  }
+
+  @Test
+  public void testListActions() {
+    try {
+      client.listActions(headers).iterator().next();
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.LIST_ACTIONS);
+  }
+
+  @Test
+  public void testListFlights() {
+    try {
+      client.listFlights(new Criteria(new byte[]{1}), headers).iterator().next();
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.LIST_FLIGHTS);
+  }
+
+  @Test
+  public void testDoAction() {
+    try {
+      client.doAction(new Action("test"), headers).next();
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.DO_ACTION);
+  }
+
+  @Test
+  public void testStartPut() {
+    try {
+      final ClientStreamListener listener = client.startPut(FlightDescriptor.command(new byte[0]),
+              new SyncPutListener(),
+              headers);
+      listener.getResult();
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.DO_PUT);
+  }
+
+  @Test
+  public void testGetStream() {
+    try (final FlightStream stream = client.getStream(new Ticket(new byte[0]), headers)) {
+      stream.next();
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.DO_GET);
+  }
+
+  @Test
+  public void testDoExchange() {
+    try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(
+            FlightDescriptor.command(new byte[0]),
+            headers)
+    ) {
+      stream.getReader().next();
+    } catch (Exception ignored) { }
+
+    assertHeadersMatch(FlightMethod.DO_EXCHANGE);
+  }
+
+  /** Asserts that the server middleware saw every entry of {@code testHeaders} on a call to {@code method}. */
+  private void assertHeadersMatch(FlightMethod method) {
+    for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
+      Assert.assertEquals("header '" + entry.getKey() + "' for " + method, entry.getValue(),
+          headersMiddleware.getCustomHeader(method, entry.getKey()));
+    }
+  }
+
+  /**
+   * A middleware used to test if customHeaders are being sent to the server properly.
+   */
+  static class TestCustomHeaderMiddleware implements FlightServerMiddleware {
+
+    @Override
+    public void onBeforeSendingHeaders(CallHeaders callHeaders) {
+    }
+
+    @Override
+    public void onCallCompleted(CallStatus callStatus) {
+    }
+
+    @Override
+    public void onCallErrored(Throwable throwable) {
+    }
+
+    /**
+     * A factory for the middleware that keeps track of the received headers and provides a way
+     * to check those values for a given Flight method.
+     *
+     * <p>{@link #onCallStarted} is invoked by the server while it handles a call, possibly on a different thread
+     * than the test that later calls {@link #getCustomHeader}, so access to the recorded headers is synchronized.
+     */
+    static class Factory implements FlightServerMiddleware.Factory<TestCustomHeaderMiddleware> {
+      private final Map<FlightMethod, CallHeaders> receivedCallHeaders = new HashMap<>();
+
+      @Override
+      public synchronized TestCustomHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders callHeaders,
+                                                                   RequestContext requestContext) {
+        receivedCallHeaders.put(callInfo.method(), callHeaders);
+        return new TestCustomHeaderMiddleware();
+      }
+
+      /** Returns the value of {@code key} received for {@code method}, or null if it was not recorded. */
+      public synchronized String getCustomHeader(FlightMethod method, String key) {
+        CallHeaders received = receivedCallHeaders.get(method);
+        return received == null ? null : received.get(key);
+      }
+    }
+  }
+}

Reply via email to