This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 969dd290ab GH-33953: [Java] Pass custom headers on every request (#33967)
969dd290ab is described below
commit 969dd290ab93c8629b643134f150e71c1c05c2ff
Author: Diego Fernández Giraldo <[email protected]>
AuthorDate: Mon Feb 27 06:41:23 2023 -0700
GH-33953: [Java] Pass custom headers on every request (#33967)
### Rationale for this change
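Some Flight RPCs (DoGet, DoPut and DoExchange) did not send the custom headers supplied via `HeaderCallOption`. Their calls were created on a pre-intercepted channel using only the option-wrapped stub's call options, which drops any interceptor the call options attached to the stub. This PR fixes that.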
### What changes are included in this PR?
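Create the DoGet, DoPut and DoExchange calls from the channel of the option-wrapped async stub (new private helper `asyncStubNewCall`). This applies interceptors attached by call options, so custom headers are sent on every request. `HeaderCallOption` now attaches its headers with `stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(...))` instead of `MetadataUtils.attachHeaders(...)`.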
### Are these changes tested?
Yes. The new `CustomHeaderTest` checks that the custom headers reach the server for each Flight RPC.
### Are there any user-facing changes?
Custom headers passed with `HeaderCallOption` are now attached to every call, including DoGet, DoPut and DoExchange.
* Closes: #33953
* Closes: #33839
Authored-by: Diego Fernandez <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../java/org/apache/arrow/flight/FlightClient.java | 25 ++-
.../org/apache/arrow/flight/HeaderCallOption.java | 2 +-
.../arrow/flight/client/CustomHeaderTest.java | 230 +++++++++++++++++++++
3 files changed, 247 insertions(+), 10 deletions(-)
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
index 1f50f50a29..8cf2bbaf25 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -79,7 +79,7 @@ public class FlightClient implements AutoCloseable {
private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
private final BufferAllocator allocator;
private final ManagedChannel channel;
- private final Channel interceptedChannel;
+
private final FlightServiceBlockingStub blockingStub;
private final FlightServiceStub asyncStub;
private final ClientAuthInterceptor authInterceptor = new
ClientAuthInterceptor();
@@ -101,7 +101,7 @@ public class FlightClient implements AutoCloseable {
interceptors = new ClientInterceptor[]{authInterceptor, new
ClientInterceptorAdapter(middleware)};
// Create a channel with interceptors pre-applied for DoGet and DoPut
- this.interceptedChannel = ClientInterceptors.intercept(channel,
interceptors);
+ Channel interceptedChannel = ClientInterceptors.intercept(channel,
interceptors);
blockingStub = FlightServiceGrpc.newBlockingStub(interceptedChannel);
asyncStub = FlightServiceGrpc.newStub(interceptedChannel);
@@ -255,13 +255,12 @@ public class FlightClient implements AutoCloseable {
CallOption... options) {
Preconditions.checkNotNull(descriptor, "descriptor must not be null");
Preconditions.checkNotNull(metadataListener, "metadataListener must not be
null");
- final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub,
options).getCallOptions();
try {
+ final ClientCall<ArrowMessage, Flight.PutResult> call =
asyncStubNewCall(doPutDescriptor, options);
final SetStreamObserver resultObserver = new
SetStreamObserver(allocator, metadataListener);
ClientCallStreamObserver<ArrowMessage> observer =
(ClientCallStreamObserver<ArrowMessage>)
- ClientCalls.asyncBidiStreamingCall(
- interceptedChannel.newCall(doPutDescriptor, callOptions),
resultObserver);
+ ClientCalls.asyncBidiStreamingCall(call, resultObserver);
return new PutObserver(
descriptor, observer, metadataListener::isCancelled,
metadataListener::getResult);
} catch (StatusRuntimeException sre) {
@@ -306,8 +305,7 @@ public class FlightClient implements AutoCloseable {
* @param options RPC-layer hints for this call.
*/
public FlightStream getStream(Ticket ticket, CallOption... options) {
- final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub,
options).getCallOptions();
- ClientCall<Flight.Ticket, ArrowMessage> call =
interceptedChannel.newCall(doGetDescriptor, callOptions);
+ final ClientCall<Flight.Ticket, ArrowMessage> call =
asyncStubNewCall(doGetDescriptor, options);
FlightStream stream = new FlightStream(
allocator,
PENDING_REQUESTS,
@@ -353,10 +351,9 @@ public class FlightClient implements AutoCloseable {
*/
public ExchangeReaderWriter doExchange(FlightDescriptor descriptor,
CallOption... options) {
Preconditions.checkNotNull(descriptor, "descriptor must not be null");
- final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub,
options).getCallOptions();
try {
- final ClientCall<ArrowMessage, ArrowMessage> call =
interceptedChannel.newCall(doExchangeDescriptor, callOptions);
+ final ClientCall<ArrowMessage, ArrowMessage> call =
asyncStubNewCall(doExchangeDescriptor, options);
final FlightStream stream = new FlightStream(allocator,
PENDING_REQUESTS, call::cancel, call::request);
final ClientCallStreamObserver<ArrowMessage> observer =
(ClientCallStreamObserver<ArrowMessage>)
ClientCalls.asyncBidiStreamingCall(call, stream.asObserver());
@@ -723,4 +720,14 @@ public class FlightClient implements AutoCloseable {
return new FlightClient(allocator, builder.build(), middleware);
}
}
+
+  /**
+   * Creates a new call for {@code descriptor} on the async stub after applying {@code options}.
+   *
+   * <p>The call is created on the channel of the option-wrapped stub, rather than only with that
+   * stub's {@link io.grpc.CallOptions}. This way, interceptors attached by call options (such as
+   * the header interceptor installed by {@link HeaderCallOption}) are applied to the call.
+   *
+   * @param descriptor the gRPC method to call
+   * @param options RPC-layer hints for this call
+   * @param <RequestT> the request message type
+   * @param <ResponseT> the response message type
+   * @return a new call; it does nothing until it is started
+   */
+  private <RequestT, ResponseT> ClientCall<RequestT, ResponseT> asyncStubNewCall(
+      MethodDescriptor<RequestT, ResponseT> descriptor,
+      CallOption... options) {
+    FlightServiceStub wrappedStub = CallOptions.wrapStub(asyncStub, options);
+    return wrappedStub.getChannel().newCall(descriptor, wrappedStub.getCallOptions());
+  }
}
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
index e2fad1a402..1a04ca3d08 100644
--- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
+++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
@@ -47,6 +47,6 @@ public class HeaderCallOption implements CallOptions.GrpcCallOption {
@Override
public <T extends AbstractStub<T>> T wrapStub(T stub) {
- return MetadataUtils.attachHeaders(stub, propertiesMetadata);
+    // withInterceptors attaches the header interceptor to the returned stub's channel. Calls must
+    // be created on that channel (not only with its getCallOptions()) for the headers to be sent.
+    return stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(propertiesMetadata));
}
}
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/CustomHeaderTest.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/CustomHeaderTest.java
new file mode 100644
index 0000000000..a320d949cd
--- /dev/null
+++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/CustomHeaderTest.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.client;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightCallHeaders;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightClient.ClientStreamListener;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightMethod;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.HeaderCallOption;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.RequestContext;
+import org.apache.arrow.flight.SyncPutListener;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableMap;
+
+
+/**
+ * Tests to ensure custom headers are passed along to the server for each Flight RPC.
+ */
+public class CustomHeaderTest {
+  private FlightServer server;
+  private FlightClient client;
+  private BufferAllocator allocator;
+  private TestCustomHeaderMiddleware.Factory headersMiddleware;
+  private HeaderCallOption headers;
+  private final Map<String, String> testHeaders = ImmutableMap.of(
+      "foo", "bar",
+      "bar", "foo",
+      "answer", "42"
+  );
+
+  @Before
+  public void setUp() throws Exception {
+    allocator = new RootAllocator(Integer.MAX_VALUE);
+    headersMiddleware = new TestCustomHeaderMiddleware.Factory();
+    FlightCallHeaders callHeaders = new FlightCallHeaders();
+    for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
+      callHeaders.insert(entry.getKey(), entry.getValue());
+    }
+    headers = new HeaderCallOption(callHeaders);
+    server = FlightTestUtil.getStartedServer(location ->
+        FlightServer.builder(allocator, location, new NoOpFlightProducer())
+            .middleware(FlightServerMiddleware.Key.of("customHeader"), headersMiddleware)
+            .build());
+    client = FlightClient.builder(allocator, server.getLocation()).build();
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    // Shut down the client and server before the allocator they were built on.
+    // Close the allocator even if shutting them down fails.
+    try {
+      AutoCloseables.close(client, server);
+    } finally {
+      allocator.getChildAllocators().forEach(BufferAllocator::close);
+      allocator.close();
+    }
+  }
+
+  @Test
+  public void testHandshake() {
+    assertHeadersSent(FlightMethod.HANDSHAKE, () -> client.handshake(headers));
+  }
+
+  @Test
+  public void testGetSchema() {
+    assertHeadersSent(FlightMethod.GET_SCHEMA,
+        () -> client.getSchema(FlightDescriptor.command(new byte[0]), headers));
+  }
+
+  @Test
+  public void testGetFlightInfo() {
+    assertHeadersSent(FlightMethod.GET_FLIGHT_INFO,
+        () -> client.getInfo(FlightDescriptor.command(new byte[0]), headers));
+  }
+
+  @Test
+  public void testListActions() {
+    assertHeadersSent(FlightMethod.LIST_ACTIONS,
+        () -> client.listActions(headers).iterator().next());
+  }
+
+  @Test
+  public void testListFlights() {
+    assertHeadersSent(FlightMethod.LIST_FLIGHTS,
+        () -> client.listFlights(new Criteria(new byte[]{1}), headers).iterator().next());
+  }
+
+  @Test
+  public void testDoAction() {
+    assertHeadersSent(FlightMethod.DO_ACTION,
+        () -> client.doAction(new Action("test"), headers).next());
+  }
+
+  @Test
+  public void testStartPut() {
+    assertHeadersSent(FlightMethod.DO_PUT, () -> {
+      final ClientStreamListener listener = client.startPut(
+          FlightDescriptor.command(new byte[0]), new SyncPutListener(), headers);
+      listener.getResult();
+    });
+  }
+
+  @Test
+  public void testGetStream() {
+    assertHeadersSent(FlightMethod.DO_GET, () -> {
+      try (final FlightStream stream = client.getStream(new Ticket(new byte[0]), headers)) {
+        stream.next();
+      }
+    });
+  }
+
+  @Test
+  public void testDoExchange() {
+    assertHeadersSent(FlightMethod.DO_EXCHANGE, () -> {
+      try (final FlightClient.ExchangeReaderWriter stream =
+          client.doExchange(FlightDescriptor.command(new byte[0]), headers)) {
+        stream.getReader().next();
+      }
+    });
+  }
+
+  /** A client interaction whose outcome (success or failure) does not matter to these tests. */
+  @FunctionalInterface
+  private interface ThrowingCall {
+    void run() throws Exception;
+  }
+
+  /**
+   * Runs {@code call}, then asserts that the server middleware received every test header for
+   * {@code method}.
+   */
+  private void assertHeadersSent(FlightMethod method, ThrowingCall call) {
+    try {
+      call.run();
+    } catch (Exception ignored) {
+      // The RPC may fail against NoOpFlightProducer. That is irrelevant here, because only the
+      // headers received by the server middleware are checked.
+    }
+    assertHeadersMatch(method);
+  }
+
+  private void assertHeadersMatch(FlightMethod method) {
+    for (Map.Entry<String, String> entry : testHeaders.entrySet()) {
+      Assert.assertEquals("header '" + entry.getKey() + "' for " + method,
+          entry.getValue(), headersMiddleware.getCustomHeader(method, entry.getKey()));
+    }
+  }
+
+  /**
+   * A middleware used to test if custom headers are being sent to the server properly.
+   */
+  static class TestCustomHeaderMiddleware implements FlightServerMiddleware {
+
+    public TestCustomHeaderMiddleware() {
+    }
+
+    @Override
+    public void onBeforeSendingHeaders(CallHeaders callHeaders) {
+
+    }
+
+    @Override
+    public void onCallCompleted(CallStatus callStatus) {
+
+    }
+
+    @Override
+    public void onCallErrored(Throwable throwable) {
+
+    }
+
+    /**
+     * A factory for the middleware that keeps track of the received headers and provides a way
+     * to check those values for a given Flight method.
+     */
+    static class Factory implements FlightServerMiddleware.Factory<TestCustomHeaderMiddleware> {
+      // onCallStarted is invoked by the server (presumably on its own threads), while
+      // getCustomHeader is called from the test thread, so the map must be thread-safe.
+      private final Map<FlightMethod, CallHeaders> receivedCallHeaders =
+          Collections.synchronizedMap(new HashMap<>());
+
+      @Override
+      public TestCustomHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders callHeaders,
+                                                      RequestContext requestContext) {
+        receivedCallHeaders.put(callInfo.method(), callHeaders);
+        return new TestCustomHeaderMiddleware();
+      }
+
+      /**
+       * Returns the value of header {@code key} recorded for {@code method}, or null if no call
+       * to {@code method} was recorded.
+       */
+      public String getCustomHeader(FlightMethod method, String key) {
+        CallHeaders headers = receivedCallHeaders.get(method);
+        return headers == null ? null : headers.get(key);
+      }
+    }
+  }
+}