This is an automated email from the ASF dual-hosted git repository. emkornfield pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new d8b8cc7 ARROW-8555: [FlightRPC][Java] implement DoExchange d8b8cc7 is described below commit d8b8cc7d714a6e7eae4eb4ebc4edc3649de4bef2 Author: David Li <li.david...@gmail.com> AuthorDate: Thu May 14 20:48:23 2020 -0700 ARROW-8555: [FlightRPC][Java] implement DoExchange This is a complete implementation of DoExchange for Java. It has not yet been tested against the C++ implementation; however, it still passes the integration tests, so the internal refactoring should not have broken compatibility with existing clients/servers. In this PR, I've refactored DoGet/DoPut/DoExchange on the client and server to share as much of their implementation as possible. DoGet/DoPut retain their behavior of "eagerly" reading/writing schemas, but DoExchange allows the client/server to delay writing the schema until it is ready. This is checked in the unit tests. I also ran into some flaky tests and tried to address them by making sure we clean things up in the right order and by adding missing `close()` calls in some existing tests. 
Closes #7012 from lidavidm/doexchange-java Authored-by: David Li <li.david...@gmail.com> Signed-off-by: Micah Kornfield <emkornfi...@gmail.com> --- java/flight/flight-core/pom.xml | 8 + .../java/org/apache/arrow/flight/ArrowMessage.java | 35 +- .../java/org/apache/arrow/flight/CallStatus.java | 2 +- .../apache/arrow/flight/FlightBindingService.java | 35 +- .../java/org/apache/arrow/flight/FlightClient.java | 197 ++++++---- .../java/org/apache/arrow/flight/FlightMethod.java | 3 + .../org/apache/arrow/flight/FlightProducer.java | 50 +-- .../java/org/apache/arrow/flight/FlightServer.java | 6 +- .../org/apache/arrow/flight/FlightService.java | 175 +++++---- .../java/org/apache/arrow/flight/FlightStream.java | 105 ++++-- .../arrow/flight/OutboundStreamListener.java | 82 +++++ .../arrow/flight/OutboundStreamListenerImpl.java | 119 +++++++ .../apache/arrow/flight/TestBasicOperation.java | 19 +- .../org/apache/arrow/flight/TestDoExchange.java | 395 +++++++++++++++++++++ .../org/apache/arrow/flight/TestErrorMetadata.java | 10 +- .../org/apache/arrow/flight/TestServerOptions.java | 17 +- 16 files changed, 1018 insertions(+), 240 deletions(-) diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index 8301c71..43ac6cc 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -132,6 +132,14 @@ <version>1.12.0</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${project.version}</version> + <classifier>tests</classifier> + <type>test-jar</type> + <scope>test</scope> + </dependency> </dependencies> <build> <extensions> diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java index fd59dd5..1758215 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ 
b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -154,6 +154,24 @@ class ArrowMessage implements AutoCloseable { this.appMetadata = null; } + /** + * Create an ArrowMessage containing only application metadata. + * @param appMetadata The application-provided metadata buffer. + */ + public ArrowMessage(ArrowBuf appMetadata) { + this.message = null; + this.bufs = ImmutableList.of(); + this.descriptor = null; + this.appMetadata = appMetadata; + } + + public ArrowMessage(FlightDescriptor descriptor) { + this.message = null; + this.bufs = ImmutableList.of(); + this.descriptor = descriptor; + this.appMetadata = null; + } + private ArrowMessage(FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf appMetadata, ArrowBuf buf) { this.message = message; @@ -171,6 +189,10 @@ class ArrowMessage implements AutoCloseable { } public HeaderType getMessageType() { + if (message == null) { + // Null message occurs for metadata-only messages (in DoExchange) + return HeaderType.NONE; + } return HeaderType.getHeader(message.headerType()); } @@ -271,8 +293,19 @@ class ArrowMessage implements AutoCloseable { * @return InputStream */ private InputStream asInputStream(BufferAllocator allocator) { - try { + if (message == null) { + // If we have no IPC message, it's a pure-metadata message + final FlightData.Builder builder = FlightData.newBuilder(); + if (descriptor != null) { + builder.setFlightDescriptor(descriptor); + } + if (appMetadata != null) { + builder.setAppMetadata(ByteString.copyFrom(appMetadata.nioBuffer())); + } + return NO_BODY_MARSHALLER.stream(builder.build()); + } + try { final ByteString bytes = ByteString.copyFrom(message.getMessageBuffer(), message.bytesAfterMessage()); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java index a43b824..991d0ed 100644 --- 
a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java @@ -137,7 +137,7 @@ public class CallStatus { "code=" + code + ", cause=" + cause + ", description='" + description + - ", metadata='" + metadata + '\'' + + "', metadata='" + metadata + '\'' + '}'; } } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java index 13051e7..ba5249b 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java @@ -45,7 +45,9 @@ class FlightBindingService implements BindableService { private static final String DO_GET = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoGet"); private static final String DO_PUT = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoPut"); - private static final Set<String> OVERRIDE_METHODS = ImmutableSet.of(DO_GET, DO_PUT); + private static final String DO_EXCHANGE = MethodDescriptor.generateFullMethodName( + FlightConstants.SERVICE, "DoExchange"); + private static final Set<String> OVERRIDE_METHODS = ImmutableSet.of(DO_GET, DO_PUT, DO_EXCHANGE); private final FlightService delegate; private final BufferAllocator allocator; @@ -78,19 +80,31 @@ class FlightBindingService implements BindableService { .build(); } + public static MethodDescriptor<ArrowMessage, ArrowMessage> getDoExchangeDescriptor(BufferAllocator allocator) { + return MethodDescriptor.<ArrowMessage, ArrowMessage>newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName(DO_EXCHANGE) + .setSampledToLocalTracing(false) + .setRequestMarshaller(ArrowMessage.createMarshaller(allocator)) + .setResponseMarshaller(ArrowMessage.createMarshaller(allocator)) + 
.setSchemaDescriptor(FlightServiceGrpc.getDoExchangeMethod().getSchemaDescriptor()) + .build(); + } + @Override public ServerServiceDefinition bindService() { final ServerServiceDefinition baseDefinition = delegate.bindService(); final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor = getDoGetDescriptor(allocator); - final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor = getDoPutDescriptor(allocator); + final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor = getDoExchangeDescriptor(allocator); // Make sure we preserve SchemaDescriptor fields on methods so that gRPC reflection still works. final ServiceDescriptor.Builder serviceDescriptorBuilder = ServiceDescriptor.newBuilder(FlightConstants.SERVICE) .setSchemaDescriptor(baseDefinition.getServiceDescriptor().getSchemaDescriptor()); serviceDescriptorBuilder.addMethod(doGetDescriptor); serviceDescriptorBuilder.addMethod(doPutDescriptor); + serviceDescriptorBuilder.addMethod(doExchangeDescriptor); for (MethodDescriptor<?, ?> definition : baseDefinition.getServiceDescriptor().getMethods()) { if (OVERRIDE_METHODS.contains(definition.getFullMethodName())) { continue; @@ -103,6 +117,7 @@ class FlightBindingService implements BindableService { ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(serviceDescriptor); serviceBuilder.addMethod(doGetDescriptor, ServerCalls.asyncServerStreamingCall(new DoGetMethod(delegate))); serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncBidiStreamingCall(new DoPutMethod(delegate))); + serviceBuilder.addMethod(doExchangeDescriptor, ServerCalls.asyncBidiStreamingCall(new DoExchangeMethod(delegate))); // copy over not-overridden methods. 
for (ServerMethodDefinition<?, ?> definition : baseDefinition.getMethods()) { @@ -116,7 +131,7 @@ class FlightBindingService implements BindableService { return serviceBuilder.build(); } - private class DoGetMethod implements ServerCalls.ServerStreamingMethod<Flight.Ticket, ArrowMessage> { + private static class DoGetMethod implements ServerCalls.ServerStreamingMethod<Flight.Ticket, ArrowMessage> { private final FlightService delegate; @@ -130,7 +145,7 @@ class FlightBindingService implements BindableService { } } - private class DoPutMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, PutResult> { + private static class DoPutMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, PutResult> { private final FlightService delegate; public DoPutMethod(FlightService delegate) { @@ -141,7 +156,19 @@ class FlightBindingService implements BindableService { public StreamObserver<ArrowMessage> invoke(StreamObserver<PutResult> responseObserver) { return delegate.doPutCustom(responseObserver); } + } + + private static class DoExchangeMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, ArrowMessage> { + private final FlightService delegate; + public DoExchangeMethod(FlightService delegate) { + this.delegate = delegate; + } + + @Override + public StreamObserver<ArrowMessage> invoke(StreamObserver<ArrowMessage> responseObserver) { + return delegate.doExchangeCustom(responseObserver); + } } } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 93f89f9..fe9cfe2 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -22,7 +22,9 @@ import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.concurrent.ExecutionException; import 
java.util.concurrent.TimeUnit; +import java.util.function.BooleanSupplier; import javax.net.ssl.SSLException; @@ -38,14 +40,11 @@ import org.apache.arrow.flight.impl.Flight.Empty; import org.apache.arrow.flight.impl.FlightServiceGrpc; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; -import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; -import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import io.grpc.Channel; import io.grpc.ClientCall; @@ -79,6 +78,7 @@ public class FlightClient implements AutoCloseable { private final ClientAuthInterceptor authInterceptor = new ClientAuthInterceptor(); private final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor; private final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor; + private final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor; /** * Create a Flight client from an allocator and a gRPC channel. @@ -98,6 +98,7 @@ public class FlightClient implements AutoCloseable { asyncStub = FlightServiceGrpc.newStub(interceptedChannel); doGetDescriptor = FlightBindingService.getDoGetDescriptor(allocator); doPutDescriptor = FlightBindingService.getDoPutDescriptor(allocator); + doExchangeDescriptor = FlightBindingService.getDoExchangeDescriptor(allocator); } /** @@ -195,31 +196,29 @@ public class FlightClient implements AutoCloseable { * @param root VectorSchemaRoot the root containing data * @param metadataListener A handler for metadata messages from the server. * @param options RPC-layer hints for this call. 
- * @return ClientStreamListener an interface to control uploading data + * @return ClientStreamListener an interface to control uploading data. + * {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will already have been called. */ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider, PutListener metadataListener, CallOption... options) { - Preconditions.checkNotNull(descriptor); - Preconditions.checkNotNull(root); + Preconditions.checkNotNull(descriptor, "descriptor must not be null"); + Preconditions.checkNotNull(root, "root must not be null"); + Preconditions.checkNotNull(provider, "provider must not be null"); + Preconditions.checkNotNull(metadataListener, "metadataListener must not be null"); + final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); try { - SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener); - final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); + final SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener); ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>) ClientCalls.asyncBidiStreamingCall( interceptedChannel.newCall(doPutDescriptor, callOptions), resultObserver); - // send the schema to start. - DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext); - return new PutObserver(new VectorUnloader( - root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */), - observer, metadataListener); + final ClientStreamListener writer = new PutObserver( + descriptor, observer, metadataListener::isCancelled, metadataListener::getResult); + // Send the schema to start. 
+ writer.start(root, provider); + return writer; } catch (StatusRuntimeException sre) { throw StatusUtils.fromGrpcRuntimeException(sre); - } catch (Exception e) { - // Only happens if DictionaryUtils#generateSchemaMessages fails. This should only happen if closing buffers fails, - // which means the application is in an unknown state, so propagate the exception. - throw CallStatus.INTERNAL.withDescription("Could not send all schema messages: " + e.toString()).withCause(e) - .toRuntimeException(); } } @@ -293,6 +292,82 @@ public class FlightClient implements AutoCloseable { return stream; } + /** + * Initiate a bidirectional data exchange with the server. + * + * @param descriptor A descriptor for the data stream. + * @param options RPC call options. + * @return A pair of a readable stream and a writable stream. + */ + public ExchangeReaderWriter doExchange(FlightDescriptor descriptor, CallOption... options) { + Preconditions.checkNotNull(descriptor, "descriptor must not be null"); + final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); + + try { + final ClientCall<ArrowMessage, ArrowMessage> call = interceptedChannel.newCall(doExchangeDescriptor, callOptions); + final FlightStream stream = new FlightStream(allocator, PENDING_REQUESTS, call::cancel, call::request); + final ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>) + ClientCalls.asyncBidiStreamingCall(call, stream.asObserver()); + final ClientStreamListener writer = new PutObserver( + descriptor, observer, stream.completed::isDone, + () -> { + try { + stream.completed.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw CallStatus.INTERNAL + .withDescription("Client error: interrupted while completing call") + .withCause(e) + .toRuntimeException(); + } catch (ExecutionException e) { + throw CallStatus.INTERNAL + .withDescription("Client error: internal while completing call") + 
.withCause(e) + .toRuntimeException(); + } + }); + // Send the descriptor to start. + try (final ArrowMessage message = new ArrowMessage(descriptor.toProtocol())) { + observer.onNext(message); + } catch (Exception e) { + throw CallStatus.INTERNAL + .withCause(e) + .withDescription("Could not write descriptor " + descriptor) + .toRuntimeException(); + } + return new ExchangeReaderWriter(stream, writer); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + } + + /** A pair of a reader and a writer for a DoExchange call. */ + public static class ExchangeReaderWriter implements AutoCloseable { + private final FlightStream reader; + private final ClientStreamListener writer; + + ExchangeReaderWriter(FlightStream reader, ClientStreamListener writer) { + this.reader = reader; + this.writer = writer; + } + + /** Get the reader for the call. */ + public FlightStream getReader() { + return reader; + } + + /** Get the writer for the call. */ + public ClientStreamListener getWriter() { + return writer; + } + + /** Shut down the streams in this call. */ + @Override + public void close() throws Exception { + reader.close(); + } + } + private static class SetStreamObserver implements StreamObserver<Flight.PutResult> { private final BufferAllocator allocator; private final StreamListener<PutResult> listener; @@ -321,81 +396,51 @@ public class FlightClient implements AutoCloseable { } } - private static class PutObserver implements ClientStreamListener { - - private final ClientCallStreamObserver<ArrowMessage> observer; - private final VectorUnloader unloader; - private final PutListener listener; - - public PutObserver(VectorUnloader unloader, ClientCallStreamObserver<ArrowMessage> observer, - PutListener listener) { - this.observer = observer; - this.unloader = unloader; - this.listener = listener; - } + /** + * The implementation of a {@link ClientStreamListener} for writing data to a Flight server. 
+ */ + static class PutObserver extends OutboundStreamListenerImpl implements ClientStreamListener { + private final BooleanSupplier isCancelled; + private final Runnable getResult; - @Override - public void putNext() { - putNext(null); + /** + * Create a new client stream listener. + * + * @param descriptor The descriptor for the stream. + * @param observer The write-side gRPC StreamObserver. + * @param isCancelled A flag to check if the call has been cancelled. + * @param getResult A flag that blocks until the overall call completes. + */ + PutObserver(FlightDescriptor descriptor, ClientCallStreamObserver<ArrowMessage> observer, + BooleanSupplier isCancelled, Runnable getResult) { + super(descriptor, observer); + Preconditions.checkNotNull(descriptor, "descriptor must be provided"); + Preconditions.checkNotNull(isCancelled, "isCancelled must be provided"); + Preconditions.checkNotNull(getResult, "getResult must be provided"); + this.isCancelled = isCancelled; + this.getResult = getResult; + this.unloader = null; } @Override - public void putNext(ArrowBuf appMetadata) { - ArrowRecordBatch batch = unloader.getRecordBatch(); + protected void waitUntilStreamReady() { // Check isCancelled as well to avoid inadvertently blocking forever // (so long as PutListener properly implements it) - while (!observer.isReady() && !listener.isCancelled()) { + while (!responseObserver.isReady() && !isCancelled.getAsBoolean()) { /* busy wait */ } - // ArrowMessage takes ownership of appMetadata and batch - // gRPC should take ownership of ArrowMessage, but in some cases it doesn't, so guard against it - // ArrowMessage#close is a no-op if gRPC did its job - try (final ArrowMessage message = new ArrowMessage(batch, appMetadata)) { - observer.onNext(message); - } catch (Exception e) { - throw StatusUtils.fromThrowable(e); - } - } - - @Override - public void error(Throwable ex) { - observer.onError(StatusUtils.toGrpcException(ex)); - } - - @Override - public void completed() { - 
observer.onCompleted(); } @Override public void getResult() { - listener.getResult(); + getResult.run(); } } /** - * Interface for subscribers to a stream returned by the server. + * Interface for writers to an Arrow data stream. */ - public interface ClientStreamListener { - - /** - * Send the current data in the corresponding {@link VectorSchemaRoot} to the server. - */ - void putNext(); - - /** - * Send the current data in the corresponding {@link VectorSchemaRoot} to the server, along with - * application-specific metadata. This takes ownership of the buffer. - */ - void putNext(ArrowBuf appMetadata); - - /** - * Indicate an error to the server. Terminates the stream; do not call {@link #completed()}. - */ - void error(Throwable ex); - - /** Indicate the stream is finished on the client side. */ - void completed(); + public interface ClientStreamListener extends OutboundStreamListener { /** * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java index 13d72db..5d2915b 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java @@ -31,6 +31,7 @@ public enum FlightMethod { DO_PUT, DO_ACTION, LIST_ACTIONS, + DO_EXCHANGE, ; /** @@ -55,6 +56,8 @@ public enum FlightMethod { return DO_ACTION; } else if (FlightServiceGrpc.getListActionsMethod().getFullMethodName().equals(methodName)) { return LIST_ACTIONS; + } else if (FlightServiceGrpc.getDoExchangeMethod().getFullMethodName().equals(methodName)) { + return DO_EXCHANGE; } throw new IllegalArgumentException("Not a Flight method name in gRPC: " + methodName); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java 
b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java index ee064ad..5e5b265 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -19,10 +19,6 @@ package org.apache.arrow.flight; import java.util.Map; -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.DictionaryProvider; - /** * API to Implement an Arrow Flight producer. */ @@ -78,6 +74,10 @@ public interface FlightProducer { Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream); + default void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + throw CallStatus.UNIMPLEMENTED.withDescription("DoExchange is unimplemented").toRuntimeException(); + } + /** * Generic handler for application-defined RPCs. * @@ -98,7 +98,7 @@ public interface FlightProducer { /** * An interface for sending Arrow data back to a client. */ - interface ServerStreamListener { + interface ServerStreamListener extends OutboundStreamListener { /** * Check whether the call has been cancelled. If so, stop sending data. @@ -106,46 +106,6 @@ public interface FlightProducer { boolean isCancelled(); /** - * A hint indicating whether the client is ready to receive data without excessive buffering. - */ - boolean isReady(); - - /** - * Start sending data, using the schema of the given {@link VectorSchemaRoot}. - * - * <p>This method must be called before all others. - */ - void start(VectorSchemaRoot root); - - /** - * Start sending data, using the schema of the given {@link VectorSchemaRoot}. - * - * <p>This method must be called before all others. - */ - void start(VectorSchemaRoot root, DictionaryProvider dictionaries); - - /** - * Send the current contents of the associated {@link VectorSchemaRoot}. 
- */ - void putNext(); - - /** - * Send the current contents of the associated {@link VectorSchemaRoot} alongside application-defined metadata. - * @param metadata The metadata to send. Ownership of the buffer is transferred to the Flight implementation. - */ - void putNext(ArrowBuf metadata); - - /** - * Indicate an error to the client. Terminates the stream; do not call {@link #completed()} afterwards. - */ - void error(Throwable ex); - - /** - * Indicate that transmission is finished. - */ - void completed(); - - /** * Set a callback for when the client cancels a call, i.e. {@link #isCancelled()} has become true. * * <p>Note that this callback may only be called some time after {@link #isCancelled()} becomes true, and may never diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java index 8523416..3c8b7ae 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -42,6 +42,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Preconditions; import org.apache.arrow.util.VisibleForTesting; +import com.google.common.util.concurrent.ThreadFactoryBuilder; + import io.grpc.Server; import io.grpc.ServerInterceptors; import io.grpc.netty.NettyServerBuilder; @@ -243,7 +245,9 @@ public class FlightServer implements AutoCloseable { exec = executor; grpcExecutor = null; } else { - exec = Executors.newCachedThreadPool(); + exec = Executors.newCachedThreadPool( + // Name threads for better debuggability + new ThreadFactoryBuilder().setNameFormat("flight-server-default-executor-%d").build()); grpcExecutor = exec; } final FlightBindingService flightService = new FlightBindingService(allocator, producer, authHandler, exec); diff --git 
a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java index 955d51f..30c7d30 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -33,13 +33,8 @@ import org.apache.arrow.flight.grpc.ServerInterceptorAdapter; import org.apache.arrow.flight.grpc.StatusUtils; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase; -import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,10 +85,12 @@ class FlightService extends FlightServiceImplBase { // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous } - public void doGetCustom(Flight.Ticket ticket, StreamObserver<ArrowMessage> responseObserver) { + public void doGetCustom(Flight.Ticket ticket, StreamObserver<ArrowMessage> responseObserverSimple) { + final ServerCallStreamObserver<ArrowMessage> responseObserver = + (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple; final GetListener listener = new GetListener(responseObserver, this::handleExceptionWithMiddleware); try { - producer.getStream(makeContext((ServerCallStreamObserver<?>) responseObserver), new Ticket(ticket), listener); + producer.getStream(makeContext(responseObserver), new Ticket(ticket), listener); } catch (Exception ex) { listener.error(ex); } @@ -126,7 +123,7 @@ class FlightService extends FlightServiceImplBase { // Do NOT call 
StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous } - private static class GetListener implements ServerStreamListener { + private static class GetListener extends OutboundStreamListenerImpl implements ServerStreamListener { private ServerCallStreamObserver<ArrowMessage> responseObserver; private final Consumer<Throwable> errorHandler; private Runnable onCancelHandler = null; @@ -134,11 +131,11 @@ class FlightService extends FlightServiceImplBase { private volatile VectorUnloader unloader; private boolean completed; - public GetListener(StreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) { - super(); + public GetListener(ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) { + super(null, responseObserver); this.errorHandler = errorHandler; this.completed = false; - this.responseObserver = (ServerCallStreamObserver<ArrowMessage>) responseObserver; + this.responseObserver = responseObserver; this.responseObserver.setOnCancelHandler(this::onCancel); this.responseObserver.disableAutoInboundFlowControl(); } @@ -156,60 +153,20 @@ class FlightService extends FlightServiceImplBase { } @Override - public boolean isReady() { - return responseObserver.isReady(); - } - - @Override public boolean isCancelled() { return responseObserver.isCancelled(); } @Override - public void start(VectorSchemaRoot root) { - start(root, new MapDictionaryProvider()); - } - - @Override - public void start(VectorSchemaRoot root, DictionaryProvider provider) { - unloader = new VectorUnloader(root, true, true); - - try { - DictionaryUtils.generateSchemaMessages(root.getSchema(), null, provider, responseObserver::onNext); - } catch (Exception e) { - // Only happens if closing buffers somehow fails - indicates application is an unknown state so propagate - // the exception - throw new RuntimeException("Could not generate and send all schema messages", e); - } - } - - @Override - public void putNext() { - 
putNext(null); - } - - @Override - public void putNext(ArrowBuf metadata) { - Preconditions.checkNotNull(unloader); - // close is a no-op if the message has been written to gRPC, otherwise frees the associated buffers - // in some code paths (e.g. if the call is cancelled), gRPC does not write the message, so we need to clean up - // ourselves. Normally, writing the ArrowMessage will transfer ownership of the data to gRPC/Netty. - try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata)) { - responseObserver.onNext(message); - } catch (Exception e) { - // This exception comes from ArrowMessage#close, not responseObserver#onNext. - // Generally this should not happen - ArrowMessage's implementation only closes non-throwing things. - // The user can't reasonably do anything about this, but if something does throw, we shouldn't let - // execution continue since other state (e.g. allocators) may be in an odd state. - throw new RuntimeException("Could not free ArrowMessage", e); - } + protected void waitUntilStreamReady() { + // Don't do anything - service implementations are expected to manage backpressure themselves } @Override public void error(Throwable ex) { if (!completed) { completed = true; - responseObserver.onError(StatusUtils.toGrpcException(ex)); + super.error(ex); } else { errorHandler.accept(ex); } @@ -217,17 +174,13 @@ class FlightService extends FlightServiceImplBase { @Override public void completed() { - if (unloader == null) { - throw new IllegalStateException("Can't complete stream before starting it"); - } if (!completed) { completed = true; - responseObserver.onCompleted(); + super.completed(); } else { errorHandler.accept(new IllegalStateException("Tried to complete already-completed call")); } } - } public StreamObserver<ArrowMessage> doPutCustom(final StreamObserver<Flight.PutResult> responseObserverSimple) { @@ -248,14 +201,15 @@ class FlightService extends FlightServiceImplBase { } catch (Exception ex) { 
ackStream.onError(ex); } finally { - // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself - // We don't do this for other streams since the implementation may be asynchronous - ackStream.ensureCompleted(); + // Close this stream before telling gRPC that the call is complete. That way we don't race with server shutdown. try { fs.close(); } catch (Exception e) { handleExceptionWithMiddleware(e); } + // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself + // We don't do this for other streams since the implementation may be asynchronous + ackStream.ensureCompleted(); } }); @@ -302,6 +256,103 @@ class FlightService extends FlightServiceImplBase { } } + /** Ensures that other resources are cleaned up when the service finishes its call. */ + private static class ExchangeListener extends GetListener { + private final AutoCloseable resource; + private boolean closed = false; + private Runnable onCancelHandler = null; + + public ExchangeListener(ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler, + AutoCloseable resource) { + super(responseObserver, errorHandler); + this.resource = resource; + super.setOnCancelHandler(() -> { + try { + if (onCancelHandler != null) { + onCancelHandler.run(); + } + } finally { + cleanup(); + } + }); + } + + private void cleanup() { + if (closed) { + // Prevent double-free. gRPC will call the OnCancelHandler even on a normal call end, which means that + // we'll double-free without this guard. 
+ return; + } + closed = true; + try { + this.resource.close(); + } catch (Exception e) { + throw CallStatus.INTERNAL + .withCause(e) + .withDescription("Server internal error cleaning up resources") + .toRuntimeException(); + } + } + + @Override + public void error(Throwable ex) { + try { + this.cleanup(); + } finally { + super.error(ex); + } + } + + @Override + public void completed() { + try { + this.cleanup(); + } finally { + super.completed(); + } + } + + @Override + public void setOnCancelHandler(Runnable handler) { + onCancelHandler = handler; + } + } + + public StreamObserver<ArrowMessage> doExchangeCustom(StreamObserver<ArrowMessage> responseObserverSimple) { + final ServerCallStreamObserver<ArrowMessage> responseObserver = + (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple; + final FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, (String message, Throwable cause) -> { + responseObserver.onError(Status.CANCELLED.withCause(cause).withDescription(message).asException()); + }, responseObserver::request); + // When service completes the call, this cleans up the FlightStream + final ExchangeListener listener = new ExchangeListener( + responseObserver, + this::handleExceptionWithMiddleware, + () -> { + // Force the stream to "complete" so it will close without incident. At this point, we don't care since + // we are about to end the call. (Normally it will raise an error.) + fs.completed.complete(null); + fs.close(); + }); + responseObserver.disableAutoInboundFlowControl(); + responseObserver.request(1); + final StreamObserver<ArrowMessage> observer = fs.asObserver(); + try { + executors.submit(() -> { + try { + producer.doExchange(makeContext(responseObserver), fs, listener); + } catch (Exception ex) { + listener.error(ex); + } + // We do not clean up or close anything here, to allow long-running asynchronous implementations. 
+ // It is the service's responsibility to call completed() or error(), which will then clean up the FlightStream. + }); + } catch (Exception ex) { + listener.error(ex); + } + return observer; + } + /** * Call context for the service. */ diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index 2302230..cbdbf05 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; @@ -60,16 +61,16 @@ public class FlightStream implements AutoCloseable { private final Cancellable cancellable; private final LinkedBlockingQueue<AutoCloseable> queue = new LinkedBlockingQueue<>(); private final SettableFuture<VectorSchemaRoot> root = SettableFuture.create(); + private final SettableFuture<FlightDescriptor> descriptor = SettableFuture.create(); private final int pendingTarget; private final Requestor requestor; + final CompletableFuture<Void> completed; private volatile int pending = 1; - private boolean completed = false; private volatile VectorSchemaRoot fulfilledRoot; private DictionaryProvider.MapDictionaryProvider dictionaries; private volatile VectorLoader loader; private volatile Throwable ex; - private volatile FlightDescriptor descriptor; private volatile ArrowBuf applicationMetadata = null; /** @@ -86,6 +87,7 @@ public class FlightStream implements AutoCloseable { this.cancellable = cancellable; this.requestor = requestor; this.dictionaries = new DictionaryProvider.MapDictionaryProvider(); + this.completed = new CompletableFuture<>(); } /** @@ -136,9 +138,15 @@ public class 
FlightStream implements AutoCloseable { * client sends the descriptor. */ public FlightDescriptor getDescriptor() { - // This blocks until the schema message (with the descriptor) is sent. - getRoot(); - return descriptor; + // This blocks until the first message from the client is received. + try { + return descriptor.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw CallStatus.INTERNAL.withCause(e).withDescription("Interrupted").toRuntimeException(); + } catch (ExecutionException e) { + throw CallStatus.INTERNAL.withCause(e).withDescription("Error getting descriptor").toRuntimeException(); + } } /** @@ -150,11 +158,13 @@ public class FlightStream implements AutoCloseable { final List<AutoCloseable> closeables = new ArrayList<>(); // cancellation can throw, but we still want to clean up resources, so make it an AutoCloseable too closeables.add(() -> { - if (!completed && cancellable != null) { + if (!completed.isDone() && cancellable != null) { cancel("Stream closed before end.", /* no exception to report */ null); } }); - closeables.add(root.get()); + if (fulfilledRoot != null) { + closeables.add(fulfilledRoot); + } closeables.add(applicationMetadata); closeables.addAll(queue); if (dictionaries != null) { @@ -162,6 +172,9 @@ public class FlightStream implements AutoCloseable { } AutoCloseables.close(closeables); + // Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception) + // No-op if already complete; do this after the check in the AutoCloseable lambda above + completed.complete(null); } /** @@ -170,21 +183,18 @@ public class FlightStream implements AutoCloseable { */ public boolean next() { try { - // make sure we have the root - root.get().clear(); - - if (completed && queue.isEmpty()) { + if (completed.isDone() && queue.isEmpty()) { return false; } - pending--; requestOutstanding(); Object data = queue.take(); if (DONE == data) { queue.put(DONE); - completed = true; + 
// Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception) + completed.complete(null); return false; } else if (DONE_EX == data) { queue.put(DONE_EX); @@ -195,18 +205,22 @@ public class FlightStream implements AutoCloseable { } } else { try (ArrowMessage msg = ((ArrowMessage) data)) { - if (msg.getMessageType() == HeaderType.RECORD_BATCH) { + if (msg.getMessageType() == HeaderType.NONE) { + updateMetadata(msg); + // We received a message without data, so erase any leftover data + if (fulfilledRoot != null) { + fulfilledRoot.clear(); + } + } else if (msg.getMessageType() == HeaderType.RECORD_BATCH) { + // Ensure we have the root + root.get().clear(); try (ArrowRecordBatch arb = msg.asRecordBatch()) { loader.load(arb); } - if (this.applicationMetadata != null) { - this.applicationMetadata.close(); - } - this.applicationMetadata = msg.getApplicationMetadata(); - if (this.applicationMetadata != null) { - this.applicationMetadata.getReferenceManager().retain(); - } + updateMetadata(msg); } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) { + // Ensure we have the root + root.get().clear(); try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) { final long id = arb.getDictionaryId(); if (dictionaries == null) { @@ -239,6 +253,17 @@ public class FlightStream implements AutoCloseable { } } + /** Update our metdata reference with a new one from this message. */ + private void updateMetadata(ArrowMessage msg) { + if (this.applicationMetadata != null) { + this.applicationMetadata.close(); + } + this.applicationMetadata = msg.getApplicationMetadata(); + if (this.applicationMetadata != null) { + this.applicationMetadata.getReferenceManager().retain(); + } + } + /** * Get the current vector data from the stream. * @@ -258,6 +283,17 @@ public class FlightStream implements AutoCloseable { } /** + * Check if there is a root (i.e. whether the other end has started sending data). 
+ * + * Updated by calls to {@link #next()}. + * + * @return true if and only if the other end has started sending data. + */ + public boolean hasRoot() { + return root.isDone(); + } + + /** * Get the most recent metadata sent from the server. This may be cleared by calls to {@link #next()} if the server * sends a message without metadata. This does NOT take ownership of the buffer - call retain() to create a reference * if you need the buffer after a call to {@link #next()}. @@ -285,6 +321,16 @@ public class FlightStream implements AutoCloseable { public void onNext(ArrowMessage msg) { requestOutstanding(); switch (msg.getMessageType()) { + case NONE: { + // No IPC message - pure metadata or descriptor + if (msg.getDescriptor() != null) { + descriptor.set(new FlightDescriptor(msg.getDescriptor())); + } + if (msg.getApplicationMetadata() != null) { + queue.add(msg); + } + break; + } case SCHEMA: { Schema schema = msg.asSchema(); final List<Field> fields = new ArrayList<>(); @@ -299,7 +345,9 @@ public class FlightStream implements AutoCloseable { schema = new Schema(fields, schema.getCustomMetadata()); fulfilledRoot = VectorSchemaRoot.create(schema, allocator); loader = new VectorLoader(fulfilledRoot); - descriptor = msg.getDescriptor() != null ? 
new FlightDescriptor(msg.getDescriptor()) : null; + if (msg.getDescriptor() != null) { + descriptor.set(new FlightDescriptor(msg.getDescriptor())); + } root.set(fulfilledRoot); break; @@ -310,7 +358,6 @@ public class FlightStream implements AutoCloseable { case DICTIONARY_BATCH: queue.add(msg); break; - case NONE: case TENSOR: default: queue.add(DONE_EX); @@ -320,18 +367,14 @@ public class FlightStream implements AutoCloseable { @Override public void onError(Throwable t) { - ex = t; + ex = StatusUtils.fromThrowable(t); queue.add(DONE_EX); - root.setException(t); + root.setException(ex); } @Override public void onCompleted() { // Depends on gRPC calling onNext and onCompleted non-concurrently - if (!root.isDone()) { - root.setException( - CallStatus.INTERNAL.withDescription("Stream completed without receiving schema.").toRuntimeException()); - } queue.add(DONE); } } @@ -342,6 +385,8 @@ public class FlightStream implements AutoCloseable { * @throws UnsupportedOperationException on a stream being uploaded from the client. */ public void cancel(String message, Throwable exception) { + completed.completeExceptionally( + CallStatus.CANCELLED.withDescription(message).withCause(exception).toRuntimeException()); if (cancellable != null) { cancellable.cancel(message, exception); } else { @@ -357,6 +402,7 @@ public class FlightStream implements AutoCloseable { /** * Provides a callback to cancel a process that is in progress. */ + @FunctionalInterface public interface Cancellable { void cancel(String message, Throwable exception); } @@ -364,6 +410,7 @@ public class FlightStream implements AutoCloseable { /** * Provides a interface to request more items from a stream producer. */ + @FunctionalInterface public interface Requestor { /** * Requests <code>count</code> more messages from the instance of this object. 
diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java new file mode 100644 index 0000000..194003c --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; + +/** + * An interface for writing data to a peer, client or server. + */ +public interface OutboundStreamListener { + + /** + * A hint indicating whether the client is ready to receive data without excessive buffering. + * + * <p>Writers should poll this flag before sending data to respect backpressure from the client and + * avoid sending data faster than the client can handle. Ignoring this flag may mean that the server + * will start consuming excessive amounts of memory, as it may buffer messages in memory. 
+ */ + boolean isReady(); + + /** + * Start sending data, using the schema of the given {@link VectorSchemaRoot}. + * + * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}. + */ + void start(VectorSchemaRoot root); + + /** + * Start sending data, using the schema of the given {@link VectorSchemaRoot}. + * + * <p>This method must be called before all others. + */ + void start(VectorSchemaRoot root, DictionaryProvider dictionaries); + + /** + * Send the current contents of the associated {@link VectorSchemaRoot}. + * + * <p>This will not necessarily block until the message is actually sent; it may buffer messages + * in memory. Use {@link #isReady()} to check if there is backpressure and avoid excessive buffering. + */ + void putNext(); + + /** + * Send the current contents of the associated {@link VectorSchemaRoot} alongside application-defined metadata. + * @param metadata The metadata to send. Ownership of the buffer is transferred to the Flight implementation. + */ + void putNext(ArrowBuf metadata); + + /** + * Send a pure metadata message without any associated data. + * + * <p>This may be called without starting the stream. + */ + void putMetadata(ArrowBuf metadata); + + /** + * Indicate an error to the client. Terminates the stream; do not call {@link #completed()} afterwards. + */ + void error(Throwable ex); + + /** + * Indicate that transmission is finished. + */ + void completed(); +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java new file mode 100644 index 0000000..c826c85 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.DictionaryProvider; + +import io.grpc.stub.CallStreamObserver; + +/** + * A base class for writing Arrow data to a Flight stream. 
+ */ +abstract class OutboundStreamListenerImpl implements OutboundStreamListener { + private final FlightDescriptor descriptor; // nullable + protected final CallStreamObserver<ArrowMessage> responseObserver; + protected volatile VectorUnloader unloader; // null until stream started + + OutboundStreamListenerImpl(FlightDescriptor descriptor, CallStreamObserver<ArrowMessage> responseObserver) { + Preconditions.checkNotNull(responseObserver, "responseObserver must be provided"); + this.descriptor = descriptor; + this.responseObserver = responseObserver; + this.unloader = null; + } + + @Override + public boolean isReady() { + return responseObserver.isReady(); + } + + @Override + public void start(VectorSchemaRoot root) { + start(root, new DictionaryProvider.MapDictionaryProvider()); + } + + @Override + public void start(VectorSchemaRoot root, DictionaryProvider dictionaries) { + try { + DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, dictionaries, responseObserver::onNext); + } catch (Exception e) { + // Only happens if closing buffers somehow fails - indicates application is an unknown state so propagate + // the exception + throw new RuntimeException("Could not generate and send all schema messages", e); + } + // We include the null count and align buffers to be compatible with Flight/C++ + unloader = new VectorUnloader(root, /* includeNullCount */ true, /* alignBuffers */ true); + } + + @Override + public void putNext() { + putNext(null); + } + + /** + * Busy-wait until the stream is ready. + * + * <p>This is overridable as client/server have different behavior. 
+ */ + protected abstract void waitUntilStreamReady(); + + @Override + public void putNext(ArrowBuf metadata) { + if (unloader == null) { + throw CallStatus.INTERNAL.withDescription("Stream was not started, call start()").toRuntimeException(); + } + + waitUntilStreamReady(); + // close is a no-op if the message has been written to gRPC, otherwise frees the associated buffers + // in some code paths (e.g. if the call is cancelled), gRPC does not write the message, so we need to clean up + // ourselves. Normally, writing the ArrowMessage will transfer ownership of the data to gRPC/Netty. + try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata)) { + responseObserver.onNext(message); + } catch (Exception e) { + // This exception comes from ArrowMessage#close, not responseObserver#onNext. + // Generally this should not happen - ArrowMessage's implementation only closes non-throwing things. + // The user can't reasonably do anything about this, but if something does throw, we shouldn't let + // execution continue since other state (e.g. allocators) may be in an odd state. 
+ throw new RuntimeException("Could not free ArrowMessage", e); + } + } + + @Override + public void putMetadata(ArrowBuf metadata) { + waitUntilStreamReady(); + try (final ArrowMessage message = new ArrowMessage(metadata)) { + responseObserver.onNext(message); + } catch (Exception e) { + throw StatusUtils.fromThrowable(e); + } + } + + @Override + public void error(Throwable ex) { + responseObserver.onError(StatusUtils.toGrpcException(ex)); + } + + @Override + public void completed() { + responseObserver.onCompleted(); + } +} diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index 3a6a676..8242bc0 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -228,15 +228,18 @@ public class TestBasicOperation { @Test public void getStream() throws Exception { test(c -> { - FlightStream stream = c.getStream(new Ticket(new byte[0])); - VectorSchemaRoot root = stream.getRoot(); - IntVector iv = (IntVector) root.getVector("c1"); - int value = 0; - while (stream.next()) { - for (int i = 0; i < root.getRowCount(); i++) { - Assert.assertEquals(value, iv.get(i)); - value++; + try (final FlightStream stream = c.getStream(new Ticket(new byte[0]))) { + VectorSchemaRoot root = stream.getRoot(); + IntVector iv = (IntVector) root.getVector("c1"); + int value = 0; + while (stream.next()) { + for (int i = 0; i < root.getRowCount(); i++) { + Assert.assertEquals(value, iv.get(i)); + value++; + } } + } catch (Exception e) { + throw new RuntimeException(e); } }); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java new file mode 100644 index 0000000..7aa95f7 --- /dev/null +++ 
b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.stream.IntStream; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import 
org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestDoExchange { + static byte[] EXCHANGE_DO_GET = "do-get".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_DO_PUT = "do-put".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_ECHO = "echo".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_METADATA_ONLY = "only-metadata".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_TRANSFORM = "transform".getBytes(StandardCharsets.UTF_8); + + private BufferAllocator allocator; + private FlightServer server; + private FlightClient client; + + @Before + public void setUp() throws Exception { + allocator = new RootAllocator(Integer.MAX_VALUE); + final Location serverLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0); + server = FlightServer.builder(allocator, serverLocation, new Producer(allocator)).build(); + server.start(); + final Location clientLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + client = FlightClient.builder(allocator, clientLocation).build(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(client, server, allocator); + } + + /** Test a pure-metadata flow. */ + @Test + public void testDoExchangeOnlyMetadata() throws Exception { + // Send a particular descriptor to the server and check for a particular response pattern. 
+ try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_METADATA_ONLY))) { + final FlightStream reader = stream.getReader(); + + // Server starts by sending a message without data (hence no VectorSchemaRoot should be present) + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(42, reader.getLatestMetadata().getInt(0)); + + // Write a metadata message to the server (without sending any data) + ArrowBuf buf = allocator.buffer(4); + buf.writeInt(84); + stream.getWriter().putMetadata(buf); + + // Check that the server echoed the metadata back to us + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(84, reader.getLatestMetadata().getInt(0)); + + // Close our write channel and ensure the server also closes theirs + stream.getWriter().completed(); + assertFalse(reader.next()); + } + } + + /** Emulate a DoGet with a DoExchange. */ + @Test + public void testDoExchangeDoGet() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { + final FlightStream reader = stream.getReader(); + VectorSchemaRoot root = reader.getRoot(); + IntVector iv = (IntVector) root.getVector("a"); + int value = 0; + while (reader.next()) { + for (int i = 0; i < root.getRowCount(); i++) { + assertFalse(String.format("Row %d should not be null", value), iv.isNull(i)); + assertEquals(value, iv.get(i)); + value++; + } + } + assertEquals(10, value); + } + } + + /** Emulate a DoPut with a DoExchange. 
*/ + @Test + public void testDoExchangeDoPut() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_PUT)); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + IntVector iv = (IntVector) root.getVector("a"); + iv.allocateNew(); + + stream.getWriter().start(root); + int counter = 0; + for (int i = 0; i < 10; i++) { + ValueVectorDataPopulator.setVector(iv, IntStream.range(0, i).boxed().toArray(Integer[]::new)); + root.setRowCount(i); + counter += i; + stream.getWriter().putNext(); + + assertTrue(stream.getReader().next()); + assertFalse(stream.getReader().hasRoot()); + // For each write, the server sends back a metadata message containing the index of the last written batch + final ArrowBuf metadata = stream.getReader().getLatestMetadata(); + assertEquals(counter, metadata.getInt(0)); + } + stream.getWriter().completed(); + + while (stream.getReader().next()) { + // Drain the stream. Otherwise closing the stream sends a CANCEL which seriously screws with the server. + // CANCEL -> runs onCancel handler -> closes the FlightStream early + } + } + } + + /** Test a DoExchange that echoes the client message. 
*/ + @Test + public void testDoExchangeEcho() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(EXCHANGE_ECHO)); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightStream reader = stream.getReader(); + + // First try writing metadata without starting the Arrow data stream + ArrowBuf buf = allocator.buffer(4); + buf.writeInt(42); + stream.getWriter().putMetadata(buf); + buf = allocator.buffer(4); + buf.writeInt(84); + stream.getWriter().putMetadata(buf); + + // Ensure that the server echoes the metadata back, also without starting its data stream + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(42, reader.getLatestMetadata().getInt(0)); + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(84, reader.getLatestMetadata().getInt(0)); + + // Write data and check that it gets echoed back. + IntVector iv = (IntVector) root.getVector("a"); + iv.allocateNew(); + stream.getWriter().start(root); + for (int i = 0; i < 10; i++) { + iv.setSafe(0, i); + root.setRowCount(1); + stream.getWriter().putNext(); + + assertTrue(reader.next()); + assertNull(reader.getLatestMetadata()); + assertEquals(root.getSchema(), reader.getSchema()); + assertEquals(i, ((IntVector) reader.getRoot().getVector("a")).get(0)); + } + + // Complete the stream so that the server knows not to expect any more messages from us. + stream.getWriter().completed(); + // The server will end its side of the call, so this shouldn't block or indicate that + // there is more data. + assertFalse("We should not be waiting for any messages", reader.next()); + } + } + + /** Write some data, have it transformed, then read it back. 
*/
+  @Test
+  public void testTransform() throws Exception {
+    final Schema schema = new Schema(Arrays.asList(
+        Field.nullable("a", new ArrowType.Int(32, true)),
+        Field.nullable("b", new ArrowType.Int(32, true))));
+    try (final FlightClient.ExchangeReaderWriter stream =
+        client.doExchange(FlightDescriptor.command(EXCHANGE_TRANSFORM))) {
+      // Write ten batches of data to the stream, where batch N contains N rows of data (N in [0, 10))
+      final FlightStream reader = stream.getReader();
+      final FlightClient.ClientStreamListener writer = stream.getWriter();
+      try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        writer.start(root);
+        for (int batchIndex = 0; batchIndex < 10; batchIndex++) {
+          for (final FieldVector rawVec : root.getFieldVectors()) {
+            final IntVector vec = (IntVector) rawVec;
+            ValueVectorDataPopulator.setVector(vec, IntStream.range(0, batchIndex).boxed().toArray(Integer[]::new));
+          }
+          root.setRowCount(batchIndex);
+          writer.putNext();
+        }
+      }
+      // Indicate that we're done writing so that the server does not expect more data.
+      writer.completed();
+
+      // Read back data. We expect the server to double each value in each row of each batch.
+      assertEquals(schema, reader.getSchema());
+      final VectorSchemaRoot root = reader.getRoot();
+      for (int batchIndex = 0; batchIndex < 10; batchIndex++) {
+        assertTrue("Didn't receive batch #" + batchIndex, reader.next());
+        assertEquals(batchIndex, root.getRowCount());
+        for (final FieldVector rawVec : root.getFieldVectors()) {
+          final IntVector vec = (IntVector) rawVec;
+          for (int row = 0; row < batchIndex; row++) {
+            assertEquals(2 * row, vec.get(row));
+          }
+        }
+      }
+
+      // The server also sends back a metadata-only message containing the message count
+      assertTrue("There should be one extra message", reader.next());
+      assertEquals(10, reader.getLatestMetadata().getInt(0)); // 10 batches were written above
+      assertFalse("There should be no more data", reader.next());
+    }
+  }
+
+  static class Producer extends NoOpFlightProducer { // dispatches DoExchange on the command bytes
+    private final BufferAllocator allocator;
+
+    Producer(BufferAllocator allocator) {
+      this.allocator = allocator;
+    }
+
+    @Override
+    public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_METADATA_ONLY)) {
+        metadataOnly(context, reader, writer);
+      } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_GET)) {
+        doGet(context, reader, writer);
+      } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_PUT)) {
+        doPut(context, reader, writer);
+      } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_ECHO)) {
+        echo(context, reader, writer);
+      } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_TRANSFORM)) {
+        transform(context, reader, writer);
+      } else {
+        writer.error(CallStatus.UNIMPLEMENTED.withDescription("Command not implemented").toRuntimeException());
+      }
+    }
+
+    /** Emulate DoGet: send five 2-row batches (values 0-9) of column "a" without reading the client. */
+    private void doGet(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        writer.start(root);
+        root.allocateNew();
+        IntVector iv = (IntVector) root.getVector("a");
+
+        for (int i = 0; i < 10; i += 2) {
+          iv.set(0, i);
+          iv.set(1, i + 1);
+          root.setRowCount(2);
+          writer.putNext();
+        }
+      }
+      writer.completed();
+    }
+
+    /** Emulate DoPut: after each data batch, reply with the running row count as 4-byte metadata. */
+    private void doPut(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      int counter = 0;
+      while (reader.next()) {
+        if (!reader.hasRoot()) {
+          writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Message has no data").toRuntimeException());
+          return;
+        }
+        counter += reader.getRoot().getRowCount();
+
+        final ArrowBuf pong = allocator.buffer(4);
+        pong.writeInt(counter);
+        writer.putMetadata(pong);
+      }
+      writer.completed();
+    }
+
+    /** Exchange metadata without ever exchanging data: send 42, then echo the client's first metadata. */
+    private void metadataOnly(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      final ArrowBuf buf = allocator.buffer(4);
+      buf.writeInt(42);
+      writer.putMetadata(buf);
+      assertTrue(reader.next());
+      assertNotNull(reader.getLatestMetadata());
+      reader.getLatestMetadata().getReferenceManager().retain(); // presumably putMetadata releases a ref
+      writer.putMetadata(reader.getLatestMetadata());
+      writer.completed();
+    }
+
+    /** Echo the client's messages back to it; the output stream is started on the first data batch. */
+    private void echo(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      VectorSchemaRoot root = null; // NOTE(review): leaks if the loop throws; consider try/finally
+      VectorLoader loader = null;
+      while (reader.next()) {
+        if (reader.hasRoot()) {
+          if (root == null) {
+            root = VectorSchemaRoot.create(reader.getSchema(), allocator);
+            loader = new VectorLoader(root);
+            writer.start(root);
+          }
+          VectorUnloader unloader = new VectorUnloader(reader.getRoot());
+          try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
+            loader.load(arb);
+          }
+          if (reader.getLatestMetadata() != null) {
+            reader.getLatestMetadata().getReferenceManager().retain();
+            writer.putNext(reader.getLatestMetadata());
+          } else {
+            writer.putNext();
+          }
+        } else {
+          // Pure metadata. NOTE(review): assumes getLatestMetadata() is non-null here - confirm
+          reader.getLatestMetadata().getReferenceManager().retain();
+          writer.putMetadata(reader.getLatestMetadata());
+        }
+      }
+      if (root != null) {
+        root.close();
+      }
+      writer.completed();
+    }
+
+    /** Double every i32 value of each incoming batch, then send the batch count as trailing metadata. */
+    private void transform(CallContext context, FlightStream reader, ServerStreamListener writer) {
+      final Schema schema = reader.getSchema();
+      for (final Field field : schema.getFields()) {
+        if (!(field.getType() instanceof ArrowType.Int)) {
+          writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Invalid type: " + field).toRuntimeException());
+          return;
+        }
+        final ArrowType.Int intType = (ArrowType.Int) field.getType();
+        if (!intType.getIsSigned() || intType.getBitWidth() != 32) {
+          writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Must be i32: " + field).toRuntimeException());
+          return;
+        }
+      }
+      int batches = 0;
+      try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+        writer.start(root);
+        final VectorLoader loader = new VectorLoader(root);
+        final VectorUnloader unloader = new VectorUnloader(reader.getRoot()); // assumes root is reused across next()
+        while (reader.next()) {
+          try (final ArrowRecordBatch batch = unloader.getRecordBatch()) {
+            loader.load(batch);
+          }
+          batches++;
+          for (final FieldVector rawVec : root.getFieldVectors()) {
+            final IntVector vec = (IntVector) rawVec;
+            for (int i = 0; i < root.getRowCount(); i++) {
+              if (!vec.isNull(i)) {
+                vec.set(i, vec.get(i) * 2);
+              }
+            }
+          }
+          writer.putNext();
+        }
+      }
+      final ArrowBuf count = allocator.buffer(4);
+      count.writeInt(batches);
+      writer.putMetadata(count);
+      writer.completed();
+    }
+  }
+}
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java
index b6d344f..02a21f2 100644
--- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java
+++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java
@@ -48,12 +48,12 @@ public class TestErrorMetadata {
             FlightTestUtil.getStartedServer(
                 (location) -> FlightServer.builder(allocator, location, new TestFlightProducer(perf)).build());
         final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
-      FlightStream stream = client.getStream(new Ticket("abs".getBytes()));
-      stream.next();
-      Assert.fail();
-    } catch (FlightRuntimeException fre) {
+      final CallStatus flightStatus = FlightTestUtil.assertCode(FlightStatusCode.CANCELLED, () -> {
+        FlightStream stream = client.getStream(new Ticket("abs".getBytes())); // NOTE(review): never closed
+        stream.next();
+      });
       PerfOuterClass.Perf newPerf = null;
-      ErrorFlightMetadata metadata = fre.status().metadata();
+      ErrorFlightMetadata metadata = flightStatus.metadata();
       Assert.assertNotNull(metadata);
       Assert.assertEquals(2, metadata.keys().size());
       Assert.assertTrue(metadata.containsKey("grpc-status-details-bin"));
diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java
index 791e0b1..363ad44 100644
--- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java
+++
b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java
@@ -124,14 +124,15 @@ public class TestServerOptions {
             (port) -> FlightServer.builder(a, location, producer).build()
         )) {
       try (FlightClient c = FlightClient.builder(a, location).build()) {
-        FlightStream stream = c.getStream(new Ticket(new byte[0]));
-        VectorSchemaRoot root = stream.getRoot();
-        IntVector iv = (IntVector) root.getVector("c1");
-        int value = 0;
-        while (stream.next()) {
-          for (int i = 0; i < root.getRowCount(); i++) {
-            Assert.assertEquals(value, iv.get(i));
-            value++;
+        try (FlightStream stream = c.getStream(new Ticket(new byte[0]))) { // closed even if an assertion fails
+          VectorSchemaRoot root = stream.getRoot();
+          IntVector iv = (IntVector) root.getVector("c1");
+          int value = 0;
+          while (stream.next()) {
+            for (int i = 0; i < root.getRowCount(); i++) {
+              Assert.assertEquals(value, iv.get(i));
+              value++;
+            }
           }
         }
       }