This is an automated email from the ASF dual-hosted git repository.
gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new f8867341ceb [grpc-query] Respect client cancellations/disconnections (#19005)
f8867341ceb is described below
commit f8867341cebc6967773d8ba1066696bed5cba30b
Author: Ben Smithgall <[email protected]>
AuthorDate: Wed Mar 4 20:27:06 2026 -0500
[grpc-query] Respect client cancellations/disconnections (#19005)
This change allows in-flight queries to be cancelled (for example, by a
gRPC client) by accepting a disconnection callback.
---
.../druid/grpc/server/GrpcEndpointInitializer.java | 12 +-
.../org/apache/druid/grpc/server/QueryDriver.java | 64 +++++-
.../org/apache/druid/grpc/server/QueryService.java | 34 ++-
.../java/org/apache/druid/grpc/BasicAuthTest.java | 4 +-
.../java/org/apache/druid/grpc/DriverTest.java | 4 +-
.../java/org/apache/druid/grpc/GrpcQueryTest.java | 4 +-
.../java/org/apache/druid/grpc/TestServer.java | 4 +-
.../apache/druid/grpc/server/QueryServiceTest.java | 227 +++++++++++++++++++++
8 files changed, 336 insertions(+), 17 deletions(-)
diff --git
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
index 68a67e3bc78..af80e3d928d 100644
---
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
+++
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
@@ -29,6 +29,7 @@ import
org.apache.druid.java.util.common.lifecycle.LifecycleStop;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.DefaultQueryConfig;
import org.apache.druid.server.QueryLifecycleFactory;
+import org.apache.druid.server.QueryScheduler;
import org.apache.druid.server.security.AuthenticatorMapper;
import org.apache.druid.sql.SqlStatementFactory;
@@ -66,12 +67,19 @@ public class GrpcEndpointInitializer
final @NativeQuery SqlStatementFactory sqlStatementFactory,
final QueryLifecycleFactory queryLifecycleFactory,
final DefaultQueryConfig defaultQueryConfig,
- final AuthenticatorMapper authMapper
+ final AuthenticatorMapper authMapper,
+ final QueryScheduler queryScheduler
)
{
this.config = config;
this.authMapper = authMapper;
- this.driver = new QueryDriver(jsonMapper, sqlStatementFactory,
defaultQueryConfig.getContext(), queryLifecycleFactory);
+ this.driver = new QueryDriver(
+ jsonMapper,
+ sqlStatementFactory,
+ defaultQueryConfig.getContext(),
+ queryLifecycleFactory,
+ queryScheduler
+ );
}
@LifecycleStart
diff --git
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
index 468783c40e4..c30b822648b 100644
---
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
+++
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
@@ -41,12 +41,14 @@ import org.apache.druid.java.util.common.guava.Accumulator;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.server.QueryLifecycle;
import org.apache.druid.server.QueryLifecycleFactory;
+import org.apache.druid.server.QueryScheduler;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthorizationResult;
@@ -69,6 +71,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
+import java.util.concurrent.atomic.AtomicReference;
/**
@@ -98,36 +101,55 @@ public class QueryDriver
private final SqlStatementFactory sqlStatementFactory;
private final Map<String, Object> defaultContext;
private final QueryLifecycleFactory queryLifecycleFactory;
+ private final QueryScheduler queryScheduler;
public QueryDriver(
final ObjectMapper jsonMapper,
final SqlStatementFactory sqlStatementFactory,
final Map<String, Object> defaultContext,
- final QueryLifecycleFactory queryLifecycleFactory
+ final QueryLifecycleFactory queryLifecycleFactory,
+ final QueryScheduler queryScheduler
)
{
this.jsonMapper = Preconditions.checkNotNull(jsonMapper, "jsonMapper");
this.sqlStatementFactory = Preconditions.checkNotNull(sqlStatementFactory,
"sqlStatementFactory");
this.defaultContext = defaultContext;
this.queryLifecycleFactory = queryLifecycleFactory;
+ this.queryScheduler = queryScheduler;
}
/**
- * First-cut synchronous query handler. Druid prefers to stream results, in
- * part to avoid overly-short network timeouts. However, for now, we simply
run
- * the query within this call and prepare the Protobuf response. Async
handling
- * can come later.
+ * Submit a query with cancellation support.
+ *
+ * @param cancelCallback will be populated with a Runnable that cancels the
query.
+ * The caller should invoke this if the query should
be cancelled.
*/
- public QueryResponse submitQuery(QueryRequest request, AuthenticationResult
authResult)
+ public QueryResponse submitQuery(
+ QueryRequest request,
+ AuthenticationResult authResult,
+ AtomicReference<Runnable> cancelCallback
+ )
{
if (request.getQueryType() == QueryOuterClass.QueryType.NATIVE) {
- return runNativeQuery(request, authResult);
+ return runNativeQuery(request, authResult, cancelCallback);
} else {
- return runSqlQuery(request, authResult);
+ return runSqlQuery(request, authResult, cancelCallback);
}
}
- private QueryResponse runNativeQuery(QueryRequest request,
AuthenticationResult authResult)
+ /**
+ * Backward-compatible method for existing tests.
+ */
+ public QueryResponse submitQuery(QueryRequest request, AuthenticationResult
authResult)
+ {
+ return submitQuery(request, authResult, new AtomicReference<>(() -> {}));
+ }
+
+ private QueryResponse runNativeQuery(
+ QueryRequest request,
+ AuthenticationResult authResult,
+ AtomicReference<Runnable> cancelCallback
+ )
{
Query<?> query;
try {
@@ -146,8 +168,14 @@ public class QueryDriver
final QueryLifecycle queryLifecycle = queryLifecycleFactory.factorize();
+ if (queryScheduler != null) {
+ final String queryId = query.getId();
+ cancelCallback.set(() -> queryScheduler.cancelQuery(queryId));
+ }
+
final org.apache.druid.server.QueryResponse queryResponse;
final String currThreadName = Thread.currentThread().getName();
+ Throwable caught = null;
try {
queryLifecycle.initialize(query);
AuthorizationResult authorizationResult =
queryLifecycle.authorize(authResult);
@@ -171,7 +199,12 @@ public class QueryDriver
.addAllColumns(encodeNativeColumns(rowSignature,
request.getSkipColumnsList()))
.build();
}
+ catch (QueryInterruptedException e) {
+ caught = e;
+ throw e;
+ }
catch (IOException | RuntimeException e) {
+ caught = e;
return QueryResponse.newBuilder()
.setQueryId(query.getId())
.setStatus(QueryStatus.RUNTIME_ERROR)
@@ -179,11 +212,16 @@ public class QueryDriver
.build();
}
finally {
+ queryLifecycle.emitLogsAndMetrics(caught, null, -1);
Thread.currentThread().setName(currThreadName);
}
}
- private QueryResponse runSqlQuery(QueryRequest request, AuthenticationResult
authResult)
+ private QueryResponse runSqlQuery(
+ QueryRequest request,
+ AuthenticationResult authResult,
+ AtomicReference<Runnable> cancelCallback
+ )
{
final SqlQueryPlus queryPlus;
try {
@@ -197,6 +235,7 @@ public class QueryDriver
.build();
}
final DirectStatement stmt =
sqlStatementFactory.directStatement(queryPlus);
+ cancelCallback.set(stmt::cancel);
final String currThreadName = Thread.currentThread().getName();
try {
Thread.currentThread().setName(StringUtils.format("grpc-sql[%s]",
stmt.sqlQueryId()));
@@ -218,6 +257,11 @@ public class QueryDriver
stmt.close();
throw e;
}
+ catch (QueryInterruptedException e) {
+ stmt.reporter().failed(e);
+ stmt.close();
+ throw e;
+ }
catch (RequestError e) {
stmt.reporter().failed(e);
stmt.close();
diff --git
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
index ddd99c456e3..e3b14ad15f1 100644
---
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
+++
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
@@ -19,14 +19,20 @@
package org.apache.druid.grpc.server;
+import com.google.common.util.concurrent.MoreExecutors;
+import io.grpc.Context;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.apache.druid.grpc.proto.QueryGrpc;
import org.apache.druid.grpc.proto.QueryOuterClass.QueryRequest;
import org.apache.druid.grpc.proto.QueryOuterClass.QueryResponse;
+import org.apache.druid.query.QueryException;
+import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.server.security.ForbiddenException;
+import java.util.concurrent.atomic.AtomicReference;
+
/**
* Implementation of the gRPC Query service. Provides a single method
* to run a query using the "driver" that holds the actual Druid SQL
@@ -44,8 +50,15 @@ class QueryService extends QueryGrpc.QueryImplBase
@Override
public void submitQuery(QueryRequest request, StreamObserver<QueryResponse>
responseObserver)
{
+ final AtomicReference<Runnable> cancelCallback = new AtomicReference<>(()
-> {});
+
+ // getAndSet ensures the callback runs at most once
+ final Runnable cancelOnce = () -> cancelCallback.getAndSet(() -> {}).run();
+
+ Context.current().addListener(context -> cancelOnce.run(),
MoreExecutors.directExecutor());
+
try {
- QueryResponse reply = driver.submitQuery(request,
QueryServer.AUTH_KEY.get());
+ QueryResponse reply = driver.submitQuery(request,
QueryServer.AUTH_KEY.get(), cancelCallback);
responseObserver.onNext(reply);
responseObserver.onCompleted();
}
@@ -55,5 +68,24 @@ class QueryService extends QueryGrpc.QueryImplBase
// handler.
responseObserver.onError(new
StatusRuntimeException(Status.PERMISSION_DENIED));
}
+ catch (QueryInterruptedException e) {
+ if (QueryException.QUERY_CANCELED_ERROR_CODE.equals(e.getErrorCode())) {
+ responseObserver.onError(new StatusRuntimeException(
+ Status.CANCELLED.withDescription(e.getMessage())
+ ));
+ } else {
+ responseObserver.onError(new StatusRuntimeException(
+ Status.INTERNAL.withDescription(e.getMessage())
+ ));
+ }
+ }
+ finally {
+ // Handle race where context was cancelled before cancelCallback was set.
+ // The listener would have invoked the no-op; now that the real callback
+ // is registered, we check again and invoke if needed.
+ if (Context.current().isCancelled()) {
+ cancelOnce.run();
+ }
+ }
}
}
diff --git
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
index d0779d3bf65..f04af5370a3 100644
---
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
+++
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
@@ -33,6 +33,7 @@ import org.apache.druid.grpc.server.QueryServer;
import org.apache.druid.metadata.DefaultPasswordProvider;
import org.apache.druid.security.basic.authentication.BasicHTTPAuthenticator;
import
org.apache.druid.security.basic.authentication.validator.CredentialsValidator;
+import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthenticatorMapper;
@@ -72,7 +73,8 @@ public class BasicAuthTest extends BaseCalciteQueryTest
sqlTestFramework.queryJsonMapper(),
plannerFixture.statementFactory(),
Map.of("forbiddenKey", "system-default-value"), // systen default
forbidden key, only superuser can change it
- sqlTestFramework.queryLifecycleFactory()
+ sqlTestFramework.queryLifecycleFactory(),
+ QueryStackTests.DEFAULT_NOOP_SCHEDULER
);
CredentialsValidator validator = new CredentialsValidator()
diff --git
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
index af6bebbabfa..0bf7aa2c0e3 100644
---
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
+++
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
@@ -27,6 +27,7 @@ import
org.apache.druid.grpc.proto.QueryOuterClass.QueryResponse;
import org.apache.druid.grpc.proto.QueryOuterClass.QueryResultFormat;
import org.apache.druid.grpc.proto.QueryOuterClass.QueryStatus;
import org.apache.druid.grpc.server.QueryDriver;
+import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.util.CalciteTests;
@@ -58,7 +59,8 @@ public class DriverTest extends BaseCalciteQueryTest
sqlTestFramework.queryJsonMapper(),
plannerFixture.statementFactory(),
Map.of(),
- sqlTestFramework.queryLifecycleFactory()
+ sqlTestFramework.queryLifecycleFactory(),
+ QueryStackTests.DEFAULT_NOOP_SCHEDULER
);
}
diff --git
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
index abf35a7dda5..06cbbdd6afe 100644
---
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
+++
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
@@ -34,6 +34,7 @@ import org.apache.druid.grpc.proto.TestResults.QueryResult;
import org.apache.druid.grpc.server.GrpcQueryConfig;
import org.apache.druid.grpc.server.QueryDriver;
import org.apache.druid.grpc.server.QueryServer;
+import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AllowAllAuthenticator;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticatorMapper;
@@ -77,7 +78,8 @@ public class GrpcQueryTest extends BaseCalciteQueryTest
sqlTestFramework.queryJsonMapper(),
plannerFixture.statementFactory(),
Map.of(),
- sqlTestFramework.queryLifecycleFactory()
+ sqlTestFramework.queryLifecycleFactory(),
+ QueryStackTests.DEFAULT_NOOP_SCHEDULER
);
AuthenticatorMapper authMapper = new AuthenticatorMapper(
ImmutableMap.of(
diff --git
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
index 03e60ae515c..faad4fa8a39 100644
---
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
+++
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
@@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableMap;
import org.apache.druid.grpc.server.GrpcEndpointInitializer;
import org.apache.druid.grpc.server.GrpcQueryConfig;
import org.apache.druid.query.DefaultQueryConfig;
+import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AllowAllAuthenticator;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticatorMapper;
@@ -60,7 +61,8 @@ public class TestServer extends BaseCalciteQueryTest
plannerFixture.statementFactory(),
null,
DefaultQueryConfig.NIL,
- authMapper
+ authMapper,
+ QueryStackTests.DEFAULT_NOOP_SCHEDULER
);
serverInit.start();
Runtime.getRuntime().addShutdownHook(new Thread()
diff --git
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/server/QueryServiceTest.java
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/server/QueryServiceTest.java
new file mode 100644
index 00000000000..7d3eec6cdae
--- /dev/null
+++
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/server/QueryServiceTest.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.grpc.server;
+
+import io.grpc.Context;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.stub.StreamObserver;
+import org.apache.druid.grpc.proto.QueryOuterClass.QueryRequest;
+import org.apache.druid.grpc.proto.QueryOuterClass.QueryResponse;
+import org.apache.druid.query.QueryException;
+import org.apache.druid.query.QueryInterruptedException;
+import org.apache.druid.server.security.AuthenticationResult;
+import org.easymock.Capture;
+import org.easymock.EasyMock;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.easymock.EasyMock.anyObject;
+import static org.easymock.EasyMock.capture;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.verify;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class QueryServiceTest
+{
+ private QueryDriver mockDriver;
+ private StreamObserver<QueryResponse> mockObserver;
+ private QueryService service;
+ private Context.CancellableContext testContext;
+
+ @BeforeEach
+ public void setup()
+ {
+ mockDriver = EasyMock.createMock(QueryDriver.class);
+ mockObserver = EasyMock.createMock(StreamObserver.class);
+ service = new QueryService(mockDriver);
+ testContext = Context.current().withCancellation();
+ }
+
+ @AfterEach
+ public void tearDown()
+ {
+ testContext.cancel(null);
+ }
+
+ @Test
+ public void testCancelledQueryReturnsGrpcCancelled()
+ {
+ QueryInterruptedException cancelException = new QueryInterruptedException(
+ new CancellationException("Query was cancelled")
+ );
+
+ expect(mockDriver.submitQuery(
+ anyObject(QueryRequest.class),
+ anyObject(AuthenticationResult.class),
+ anyObject(AtomicReference.class)
+ )).andThrow(cancelException);
+
+ Capture<Throwable> errorCapture = EasyMock.newCapture();
+ mockObserver.onError(capture(errorCapture));
+ expectLastCall();
+
+ replay(mockDriver, mockObserver);
+
+ testContext.run(() ->
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+ verify(mockDriver, mockObserver);
+
+ StatusRuntimeException statusException = (StatusRuntimeException)
errorCapture.getValue();
+ assertEquals(Status.Code.CANCELLED, statusException.getStatus().getCode());
+ }
+
+ @Test
+ public void testInterruptedQueryReturnsGrpcInternal()
+ {
+ QueryInterruptedException interruptException = new
QueryInterruptedException(
+ QueryException.QUERY_INTERRUPTED_ERROR_CODE,
+ "Query was interrupted",
+ null,
+ null
+ );
+
+ expect(mockDriver.submitQuery(
+ anyObject(QueryRequest.class),
+ anyObject(AuthenticationResult.class),
+ anyObject(AtomicReference.class)
+ )).andThrow(interruptException);
+
+ Capture<Throwable> errorCapture = EasyMock.newCapture();
+ mockObserver.onError(capture(errorCapture));
+ expectLastCall();
+
+ replay(mockDriver, mockObserver);
+
+ testContext.run(() ->
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+ verify(mockDriver, mockObserver);
+
+ StatusRuntimeException statusException = (StatusRuntimeException)
errorCapture.getValue();
+ assertEquals(Status.Code.INTERNAL, statusException.getStatus().getCode());
+ }
+
+ @Test
+ public void testSuccessfulQueryCompletes()
+ {
+ expect(mockDriver.submitQuery(
+ anyObject(QueryRequest.class),
+ anyObject(AuthenticationResult.class),
+ anyObject(AtomicReference.class)
+ )).andReturn(QueryResponse.getDefaultInstance());
+
+ mockObserver.onNext(anyObject(QueryResponse.class));
+ expectLastCall();
+ mockObserver.onCompleted();
+ expectLastCall();
+
+ replay(mockDriver, mockObserver);
+
+ testContext.run(() ->
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+ verify(mockDriver, mockObserver);
+ }
+
+ @Test
+ public void testContextCancellationInvokesCancelCallback() throws Exception
+ {
+ AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+ Capture<AtomicReference<Runnable>> callbackCapture = EasyMock.newCapture();
+
+ expect(mockDriver.submitQuery(
+ anyObject(QueryRequest.class),
+ anyObject(AuthenticationResult.class),
+ capture(callbackCapture)
+ )).andAnswer(() -> {
+ // Register a callback that tracks invocation
+ callbackCapture.getValue().set(() -> callbackInvoked.set(true));
+ return QueryResponse.getDefaultInstance();
+ });
+
+ mockObserver.onNext(anyObject(QueryResponse.class));
+ expectLastCall();
+ mockObserver.onCompleted();
+ expectLastCall();
+
+ replay(mockDriver, mockObserver);
+
+ testContext.run(() ->
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+ // Now cancel the context - listener should fire and invoke our callback
+ testContext.cancel(null);
+
+ assertTrue(callbackInvoked.get(), "Cancel callback should have been
invoked");
+ verify(mockDriver, mockObserver);
+ }
+
+ @Test
+ public void testMidExecutionCancellationInvokesCallback() throws Exception
+ {
+ CountDownLatch queryStarted = new CountDownLatch(1);
+ CountDownLatch contextCancelled = new CountDownLatch(1);
+ AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+
+ Capture<AtomicReference<Runnable>> callbackCapture = EasyMock.newCapture();
+
+ expect(mockDriver.submitQuery(
+ anyObject(QueryRequest.class),
+ anyObject(AuthenticationResult.class),
+ capture(callbackCapture)
+ )).andAnswer(() -> {
+ callbackCapture.getValue().set(() -> callbackInvoked.set(true));
+ queryStarted.countDown();
+ contextCancelled.await(1, TimeUnit.SECONDS);
+ throw new QueryInterruptedException(new
CancellationException("Cancelled"));
+ });
+
+ Capture<Throwable> errorCapture = EasyMock.newCapture();
+ mockObserver.onError(capture(errorCapture));
+ expectLastCall();
+
+ replay(mockDriver, mockObserver);
+
+ Thread queryThread = new Thread(() ->
+ testContext.run(() ->
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver))
+ );
+ queryThread.start();
+
+ assertTrue(queryStarted.await(1, TimeUnit.SECONDS), "Query should start");
+ testContext.cancel(null);
+ contextCancelled.countDown();
+
+ queryThread.join(1000);
+
+ assertTrue(callbackInvoked.get(), "Cancel callback should have been
invoked");
+ verify(mockDriver, mockObserver);
+
+ StatusRuntimeException statusException = (StatusRuntimeException)
errorCapture.getValue();
+ assertEquals(Status.Code.CANCELLED, statusException.getStatus().getCode());
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]