This is an automated email from the ASF dual-hosted git repository.

gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new f8867341ceb [grpc-query] Respect client cancellations/disconnections (#19005)
f8867341ceb is described below

commit f8867341cebc6967773d8ba1066696bed5cba30b
Author: Ben Smithgall <[email protected]>
AuthorDate: Wed Mar 4 20:27:06 2026 -0500

    [grpc-query] Respect client cancellations/disconnections (#19005)
    
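    This change allows in-flight queries to be cancelled (for example, when a
    gRPC client cancels the call or disconnects). QueryDriver.submitQuery now
    accepts a holder that it populates with a cancellation callback
    (QueryScheduler.cancelQuery for native queries, DirectStatement.cancel for
    SQL), and QueryService invokes that callback when the gRPC Context is
    cancelled.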
---
 .../druid/grpc/server/GrpcEndpointInitializer.java |  12 +-
 .../org/apache/druid/grpc/server/QueryDriver.java  |  64 +++++-
 .../org/apache/druid/grpc/server/QueryService.java |  34 ++-
 .../java/org/apache/druid/grpc/BasicAuthTest.java  |   4 +-
 .../java/org/apache/druid/grpc/DriverTest.java     |   4 +-
 .../java/org/apache/druid/grpc/GrpcQueryTest.java  |   4 +-
 .../java/org/apache/druid/grpc/TestServer.java     |   4 +-
 .../apache/druid/grpc/server/QueryServiceTest.java | 227 +++++++++++++++++++++
 8 files changed, 336 insertions(+), 17 deletions(-)

diff --git 
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
 
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
index 68a67e3bc78..af80e3d928d 100644
--- 
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
+++ 
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/GrpcEndpointInitializer.java
@@ -29,6 +29,7 @@ import 
org.apache.druid.java.util.common.lifecycle.LifecycleStop;
 import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.query.DefaultQueryConfig;
 import org.apache.druid.server.QueryLifecycleFactory;
+import org.apache.druid.server.QueryScheduler;
 import org.apache.druid.server.security.AuthenticatorMapper;
 import org.apache.druid.sql.SqlStatementFactory;
 
@@ -66,12 +67,19 @@ public class GrpcEndpointInitializer
       final @NativeQuery SqlStatementFactory sqlStatementFactory,
       final QueryLifecycleFactory queryLifecycleFactory,
       final DefaultQueryConfig defaultQueryConfig,
-      final AuthenticatorMapper authMapper
+      final AuthenticatorMapper authMapper,
+      final QueryScheduler queryScheduler
   )
   {
     this.config = config;
     this.authMapper = authMapper;
-    this.driver = new QueryDriver(jsonMapper, sqlStatementFactory, 
defaultQueryConfig.getContext(), queryLifecycleFactory);
+    this.driver = new QueryDriver(
+        jsonMapper,
+        sqlStatementFactory,
+        defaultQueryConfig.getContext(),
+        queryLifecycleFactory,
+        queryScheduler
+    );
   }
 
   @LifecycleStart
diff --git 
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
 
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
index 468783c40e4..c30b822648b 100644
--- 
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
+++ 
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryDriver.java
@@ -41,12 +41,14 @@ import org.apache.druid.java.util.common.guava.Accumulator;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryInterruptedException;
 import org.apache.druid.query.QueryToolChest;
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.server.QueryLifecycle;
 import org.apache.druid.server.QueryLifecycleFactory;
+import org.apache.druid.server.QueryScheduler;
 import org.apache.druid.server.security.Access;
 import org.apache.druid.server.security.AuthenticationResult;
 import org.apache.druid.server.security.AuthorizationResult;
@@ -69,6 +71,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.atomic.AtomicReference;
 
 
 /**
@@ -98,36 +101,55 @@ public class QueryDriver
   private final SqlStatementFactory sqlStatementFactory;
   private final Map<String, Object> defaultContext;
   private final QueryLifecycleFactory queryLifecycleFactory;
+  private final QueryScheduler queryScheduler;
 
   public QueryDriver(
       final ObjectMapper jsonMapper,
       final SqlStatementFactory sqlStatementFactory,
       final Map<String, Object> defaultContext,
-      final QueryLifecycleFactory queryLifecycleFactory
+      final QueryLifecycleFactory queryLifecycleFactory,
+      final QueryScheduler queryScheduler
   )
   {
     this.jsonMapper = Preconditions.checkNotNull(jsonMapper, "jsonMapper");
     this.sqlStatementFactory = Preconditions.checkNotNull(sqlStatementFactory, 
"sqlStatementFactory");
     this.defaultContext = defaultContext;
     this.queryLifecycleFactory = queryLifecycleFactory;
+    this.queryScheduler = queryScheduler;
   }
 
   /**
-   * First-cut synchronous query handler. Druid prefers to stream results, in
-   * part to avoid overly-short network timeouts. However, for now, we simply 
run
-   * the query within this call and prepare the Protobuf response. Async 
handling
-   * can come later.
+   * Submit a query with cancellation support.
+   *
+   * @param cancelCallback will be populated with a Runnable that cancels the 
query.
+   *                       The caller should invoke this if the query should 
be cancelled.
    */
-  public QueryResponse submitQuery(QueryRequest request, AuthenticationResult 
authResult)
+  public QueryResponse submitQuery(
+      QueryRequest request,
+      AuthenticationResult authResult,
+      AtomicReference<Runnable> cancelCallback
+  )
   {
     if (request.getQueryType() == QueryOuterClass.QueryType.NATIVE) {
-      return runNativeQuery(request, authResult);
+      return runNativeQuery(request, authResult, cancelCallback);
     } else {
-      return runSqlQuery(request, authResult);
+      return runSqlQuery(request, authResult, cancelCallback);
     }
   }
 
-  private QueryResponse runNativeQuery(QueryRequest request, 
AuthenticationResult authResult)
+  /**
+   * Backward-compatible method for existing tests.
+   */
+  public QueryResponse submitQuery(QueryRequest request, AuthenticationResult 
authResult)
+  {
+    return submitQuery(request, authResult, new AtomicReference<>(() -> {}));
+  }
+
+  private QueryResponse runNativeQuery(
+      QueryRequest request,
+      AuthenticationResult authResult,
+      AtomicReference<Runnable> cancelCallback
+  )
   {
     Query<?> query;
     try {
@@ -146,8 +168,14 @@ public class QueryDriver
 
     final QueryLifecycle queryLifecycle = queryLifecycleFactory.factorize();
 
+    if (queryScheduler != null) {
+      final String queryId = query.getId();
+      cancelCallback.set(() -> queryScheduler.cancelQuery(queryId));
+    }
+
     final org.apache.druid.server.QueryResponse queryResponse;
     final String currThreadName = Thread.currentThread().getName();
+    Throwable caught = null;
     try {
       queryLifecycle.initialize(query);
       AuthorizationResult authorizationResult = 
queryLifecycle.authorize(authResult);
@@ -171,7 +199,12 @@ public class QueryDriver
                           .addAllColumns(encodeNativeColumns(rowSignature, 
request.getSkipColumnsList()))
                           .build();
     }
+    catch (QueryInterruptedException e) {
+      caught = e;
+      throw e;
+    }
     catch (IOException | RuntimeException e) {
+      caught = e;
       return QueryResponse.newBuilder()
                           .setQueryId(query.getId())
                           .setStatus(QueryStatus.RUNTIME_ERROR)
@@ -179,11 +212,16 @@ public class QueryDriver
                           .build();
     }
     finally {
+      queryLifecycle.emitLogsAndMetrics(caught, null, -1);
       Thread.currentThread().setName(currThreadName);
     }
   }
 
-  private QueryResponse runSqlQuery(QueryRequest request, AuthenticationResult 
authResult)
+  private QueryResponse runSqlQuery(
+      QueryRequest request,
+      AuthenticationResult authResult,
+      AtomicReference<Runnable> cancelCallback
+  )
   {
     final SqlQueryPlus queryPlus;
     try {
@@ -197,6 +235,7 @@ public class QueryDriver
                           .build();
     }
     final DirectStatement stmt = 
sqlStatementFactory.directStatement(queryPlus);
+    cancelCallback.set(stmt::cancel);
     final String currThreadName = Thread.currentThread().getName();
     try {
       Thread.currentThread().setName(StringUtils.format("grpc-sql[%s]", 
stmt.sqlQueryId()));
@@ -218,6 +257,11 @@ public class QueryDriver
       stmt.close();
       throw e;
     }
+    catch (QueryInterruptedException e) {
+      stmt.reporter().failed(e);
+      stmt.close();
+      throw e;
+    }
     catch (RequestError e) {
       stmt.reporter().failed(e);
       stmt.close();
diff --git 
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
 
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
index ddd99c456e3..e3b14ad15f1 100644
--- 
a/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
+++ 
b/extensions-contrib/grpc-query/src/main/java/org/apache/druid/grpc/server/QueryService.java
@@ -19,14 +19,20 @@
 
 package org.apache.druid.grpc.server;
 
+import com.google.common.util.concurrent.MoreExecutors;
+import io.grpc.Context;
 import io.grpc.Status;
 import io.grpc.StatusRuntimeException;
 import io.grpc.stub.StreamObserver;
 import org.apache.druid.grpc.proto.QueryGrpc;
 import org.apache.druid.grpc.proto.QueryOuterClass.QueryRequest;
 import org.apache.druid.grpc.proto.QueryOuterClass.QueryResponse;
+import org.apache.druid.query.QueryException;
+import org.apache.druid.query.QueryInterruptedException;
 import org.apache.druid.server.security.ForbiddenException;
 
+import java.util.concurrent.atomic.AtomicReference;
+
 /**
  * Implementation of the gRPC Query service. Provides a single method
  * to run a query using the "driver" that holds the actual Druid SQL
@@ -44,8 +50,15 @@ class QueryService extends QueryGrpc.QueryImplBase
   @Override
   public void submitQuery(QueryRequest request, StreamObserver<QueryResponse> 
responseObserver)
   {
+    final AtomicReference<Runnable> cancelCallback = new AtomicReference<>(() 
-> {});
+
+    // getAndSet ensures the callback runs at most once
+    final Runnable cancelOnce = () -> cancelCallback.getAndSet(() -> {}).run();
+
+    Context.current().addListener(context -> cancelOnce.run(), 
MoreExecutors.directExecutor());
+
     try {
-      QueryResponse reply = driver.submitQuery(request, 
QueryServer.AUTH_KEY.get());
+      QueryResponse reply = driver.submitQuery(request, 
QueryServer.AUTH_KEY.get(), cancelCallback);
       responseObserver.onNext(reply);
       responseObserver.onCompleted();
     }
@@ -55,5 +68,24 @@ class QueryService extends QueryGrpc.QueryImplBase
       // handler.
       responseObserver.onError(new 
StatusRuntimeException(Status.PERMISSION_DENIED));
     }
+    catch (QueryInterruptedException e) {
+      if (QueryException.QUERY_CANCELED_ERROR_CODE.equals(e.getErrorCode())) {
+        responseObserver.onError(new StatusRuntimeException(
+            Status.CANCELLED.withDescription(e.getMessage())
+        ));
+      } else {
+        responseObserver.onError(new StatusRuntimeException(
+            Status.INTERNAL.withDescription(e.getMessage())
+        ));
+      }
+    }
+    finally {
+      // Handle race where context was cancelled before cancelCallback was set.
+      // The listener would have invoked the no-op; now that the real callback
+      // is registered, we check again and invoke if needed.
+      if (Context.current().isCancelled()) {
+        cancelOnce.run();
+      }
+    }
   }
 }
diff --git 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
index d0779d3bf65..f04af5370a3 100644
--- 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
+++ 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/BasicAuthTest.java
@@ -33,6 +33,7 @@ import org.apache.druid.grpc.server.QueryServer;
 import org.apache.druid.metadata.DefaultPasswordProvider;
 import org.apache.druid.security.basic.authentication.BasicHTTPAuthenticator;
 import 
org.apache.druid.security.basic.authentication.validator.CredentialsValidator;
+import org.apache.druid.server.QueryStackTests;
 import org.apache.druid.server.security.AuthConfig;
 import org.apache.druid.server.security.AuthenticationResult;
 import org.apache.druid.server.security.AuthenticatorMapper;
@@ -72,7 +73,8 @@ public class BasicAuthTest extends BaseCalciteQueryTest
         sqlTestFramework.queryJsonMapper(),
         plannerFixture.statementFactory(),
         Map.of("forbiddenKey", "system-default-value"), // systen default 
forbidden key, only superuser can change it
-        sqlTestFramework.queryLifecycleFactory()
+        sqlTestFramework.queryLifecycleFactory(),
+        QueryStackTests.DEFAULT_NOOP_SCHEDULER
     );
 
     CredentialsValidator validator = new CredentialsValidator()
diff --git 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
index af6bebbabfa..0bf7aa2c0e3 100644
--- 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
+++ 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/DriverTest.java
@@ -27,6 +27,7 @@ import 
org.apache.druid.grpc.proto.QueryOuterClass.QueryResponse;
 import org.apache.druid.grpc.proto.QueryOuterClass.QueryResultFormat;
 import org.apache.druid.grpc.proto.QueryOuterClass.QueryStatus;
 import org.apache.druid.grpc.server.QueryDriver;
+import org.apache.druid.server.QueryStackTests;
 import org.apache.druid.server.security.AuthConfig;
 import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
 import org.apache.druid.sql.calcite.util.CalciteTests;
@@ -58,7 +59,8 @@ public class DriverTest extends BaseCalciteQueryTest
         sqlTestFramework.queryJsonMapper(),
         plannerFixture.statementFactory(),
         Map.of(),
-        sqlTestFramework.queryLifecycleFactory()
+        sqlTestFramework.queryLifecycleFactory(),
+        QueryStackTests.DEFAULT_NOOP_SCHEDULER
     );
   }
 
diff --git 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
index abf35a7dda5..06cbbdd6afe 100644
--- 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
+++ 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/GrpcQueryTest.java
@@ -34,6 +34,7 @@ import org.apache.druid.grpc.proto.TestResults.QueryResult;
 import org.apache.druid.grpc.server.GrpcQueryConfig;
 import org.apache.druid.grpc.server.QueryDriver;
 import org.apache.druid.grpc.server.QueryServer;
+import org.apache.druid.server.QueryStackTests;
 import org.apache.druid.server.security.AllowAllAuthenticator;
 import org.apache.druid.server.security.AuthConfig;
 import org.apache.druid.server.security.AuthenticatorMapper;
@@ -77,7 +78,8 @@ public class GrpcQueryTest extends BaseCalciteQueryTest
         sqlTestFramework.queryJsonMapper(),
         plannerFixture.statementFactory(),
         Map.of(),
-        sqlTestFramework.queryLifecycleFactory()
+        sqlTestFramework.queryLifecycleFactory(),
+        QueryStackTests.DEFAULT_NOOP_SCHEDULER
     );
     AuthenticatorMapper authMapper = new AuthenticatorMapper(
         ImmutableMap.of(
diff --git 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
index 03e60ae515c..faad4fa8a39 100644
--- 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
+++ 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/TestServer.java
@@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableMap;
 import org.apache.druid.grpc.server.GrpcEndpointInitializer;
 import org.apache.druid.grpc.server.GrpcQueryConfig;
 import org.apache.druid.query.DefaultQueryConfig;
+import org.apache.druid.server.QueryStackTests;
 import org.apache.druid.server.security.AllowAllAuthenticator;
 import org.apache.druid.server.security.AuthConfig;
 import org.apache.druid.server.security.AuthenticatorMapper;
@@ -60,7 +61,8 @@ public class TestServer extends BaseCalciteQueryTest
         plannerFixture.statementFactory(),
         null,
         DefaultQueryConfig.NIL,
-        authMapper
+        authMapper,
+        QueryStackTests.DEFAULT_NOOP_SCHEDULER
     );
     serverInit.start();
     Runtime.getRuntime().addShutdownHook(new Thread()
diff --git 
a/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/server/QueryServiceTest.java
 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/server/QueryServiceTest.java
new file mode 100644
index 00000000000..7d3eec6cdae
--- /dev/null
+++ 
b/extensions-contrib/grpc-query/src/test/java/org/apache/druid/grpc/server/QueryServiceTest.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.grpc.server;
+
+import io.grpc.Context;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.stub.StreamObserver;
+import org.apache.druid.grpc.proto.QueryOuterClass.QueryRequest;
+import org.apache.druid.grpc.proto.QueryOuterClass.QueryResponse;
+import org.apache.druid.query.QueryException;
+import org.apache.druid.query.QueryInterruptedException;
+import org.apache.druid.server.security.AuthenticationResult;
+import org.easymock.Capture;
+import org.easymock.EasyMock;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.easymock.EasyMock.anyObject;
+import static org.easymock.EasyMock.capture;
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.verify;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class QueryServiceTest
+{
+  private QueryDriver mockDriver;
+  private StreamObserver<QueryResponse> mockObserver;
+  private QueryService service;
+  private Context.CancellableContext testContext;
+
+  @BeforeEach
+  public void setup()
+  {
+    mockDriver = EasyMock.createMock(QueryDriver.class);
+    mockObserver = EasyMock.createMock(StreamObserver.class);
+    service = new QueryService(mockDriver);
+    testContext = Context.current().withCancellation();
+  }
+
+  @AfterEach
+  public void tearDown()
+  {
+    testContext.cancel(null);
+  }
+
+  @Test
+  public void testCancelledQueryReturnsGrpcCancelled()
+  {
+    QueryInterruptedException cancelException = new QueryInterruptedException(
+        new CancellationException("Query was cancelled")
+    );
+
+    expect(mockDriver.submitQuery(
+        anyObject(QueryRequest.class),
+        anyObject(AuthenticationResult.class),
+        anyObject(AtomicReference.class)
+    )).andThrow(cancelException);
+
+    Capture<Throwable> errorCapture = EasyMock.newCapture();
+    mockObserver.onError(capture(errorCapture));
+    expectLastCall();
+
+    replay(mockDriver, mockObserver);
+
+    testContext.run(() -> 
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+    verify(mockDriver, mockObserver);
+
+    StatusRuntimeException statusException = (StatusRuntimeException) 
errorCapture.getValue();
+    assertEquals(Status.Code.CANCELLED, statusException.getStatus().getCode());
+  }
+
+  @Test
+  public void testInterruptedQueryReturnsGrpcInternal()
+  {
+    QueryInterruptedException interruptException = new 
QueryInterruptedException(
+        QueryException.QUERY_INTERRUPTED_ERROR_CODE,
+        "Query was interrupted",
+        null,
+        null
+    );
+
+    expect(mockDriver.submitQuery(
+        anyObject(QueryRequest.class),
+        anyObject(AuthenticationResult.class),
+        anyObject(AtomicReference.class)
+    )).andThrow(interruptException);
+
+    Capture<Throwable> errorCapture = EasyMock.newCapture();
+    mockObserver.onError(capture(errorCapture));
+    expectLastCall();
+
+    replay(mockDriver, mockObserver);
+
+    testContext.run(() -> 
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+    verify(mockDriver, mockObserver);
+
+    StatusRuntimeException statusException = (StatusRuntimeException) 
errorCapture.getValue();
+    assertEquals(Status.Code.INTERNAL, statusException.getStatus().getCode());
+  }
+
+  @Test
+  public void testSuccessfulQueryCompletes()
+  {
+    expect(mockDriver.submitQuery(
+        anyObject(QueryRequest.class),
+        anyObject(AuthenticationResult.class),
+        anyObject(AtomicReference.class)
+    )).andReturn(QueryResponse.getDefaultInstance());
+
+    mockObserver.onNext(anyObject(QueryResponse.class));
+    expectLastCall();
+    mockObserver.onCompleted();
+    expectLastCall();
+
+    replay(mockDriver, mockObserver);
+
+    testContext.run(() -> 
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+    verify(mockDriver, mockObserver);
+  }
+
+  @Test
+  public void testContextCancellationInvokesCancelCallback() throws Exception
+  {
+    AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+    Capture<AtomicReference<Runnable>> callbackCapture = EasyMock.newCapture();
+
+    expect(mockDriver.submitQuery(
+        anyObject(QueryRequest.class),
+        anyObject(AuthenticationResult.class),
+        capture(callbackCapture)
+    )).andAnswer(() -> {
+      // Register a callback that tracks invocation
+      callbackCapture.getValue().set(() -> callbackInvoked.set(true));
+      return QueryResponse.getDefaultInstance();
+    });
+
+    mockObserver.onNext(anyObject(QueryResponse.class));
+    expectLastCall();
+    mockObserver.onCompleted();
+    expectLastCall();
+
+    replay(mockDriver, mockObserver);
+
+    testContext.run(() -> 
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver));
+
+    // Now cancel the context - listener should fire and invoke our callback
+    testContext.cancel(null);
+
+    assertTrue(callbackInvoked.get(), "Cancel callback should have been 
invoked");
+    verify(mockDriver, mockObserver);
+  }
+
+  @Test
+  public void testMidExecutionCancellationInvokesCallback() throws Exception
+  {
+    CountDownLatch queryStarted = new CountDownLatch(1);
+    CountDownLatch contextCancelled = new CountDownLatch(1);
+    AtomicBoolean callbackInvoked = new AtomicBoolean(false);
+
+    Capture<AtomicReference<Runnable>> callbackCapture = EasyMock.newCapture();
+
+    expect(mockDriver.submitQuery(
+        anyObject(QueryRequest.class),
+        anyObject(AuthenticationResult.class),
+        capture(callbackCapture)
+    )).andAnswer(() -> {
+      callbackCapture.getValue().set(() -> callbackInvoked.set(true));
+      queryStarted.countDown();
+      contextCancelled.await(1, TimeUnit.SECONDS);
+      throw new QueryInterruptedException(new 
CancellationException("Cancelled"));
+    });
+
+    Capture<Throwable> errorCapture = EasyMock.newCapture();
+    mockObserver.onError(capture(errorCapture));
+    expectLastCall();
+
+    replay(mockDriver, mockObserver);
+
+    Thread queryThread = new Thread(() ->
+        testContext.run(() -> 
service.submitQuery(QueryRequest.getDefaultInstance(), mockObserver))
+    );
+    queryThread.start();
+
+    assertTrue(queryStarted.await(1, TimeUnit.SECONDS), "Query should start");
+    testContext.cancel(null);
+    contextCancelled.countDown();
+
+    queryThread.join(1000);
+
+    assertTrue(callbackInvoked.get(), "Cancel callback should have been 
invoked");
+    verify(mockDriver, mockObserver);
+
+    StatusRuntimeException statusException = (StatusRuntimeException) 
errorCapture.getValue();
+    assertEquals(Status.Code.CANCELLED, statusException.getStatus().getCode());
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to