This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 24aef1df4e4 Fix query cancellation on op-chain not scheduled (#16625)
24aef1df4e4 is described below

commit 24aef1df4e4d8caa3146516cebc43d00fb8ae7a3
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Wed Aug 20 17:35:34 2025 -0600

    Fix query cancellation on op-chain not scheduled (#16625)
---
 .../runtime/executor/OpChainSchedulerService.java  | 86 +++++++++++++++----
 .../plan/pipeline/PipelineBreakerExecutor.java     | 13 ++-
 .../executor/OpChainSchedulerServiceTest.java      | 98 +++++++++++++++++++---
 .../apache/pinot/spi/utils/CommonConstants.java    | 10 +++
 4 files changed, 176 insertions(+), 31 deletions(-)

diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
index 18d34f35c67..807e0d038a5 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
@@ -27,6 +27,9 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.stream.Collectors;
 import org.apache.pinot.core.util.trace.TraceRunnable;
 import org.apache.pinot.query.runtime.blocks.ErrorMseBlock;
@@ -37,44 +40,82 @@ import org.apache.pinot.query.runtime.operator.OpChainId;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.spi.accounting.ThreadExecutionContext;
 import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.exception.QueryCancelledException;
 import org.apache.pinot.spi.trace.Tracing;
 import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+
+/// TODO: Use CID to manage the queries
 public class OpChainSchedulerService {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(OpChainSchedulerService.class);
+  private static final int NUM_QUERY_LOCKS = 1 << 10; // 1024 locks
+  private static final int QUERY_LOCK_MASK = NUM_QUERY_LOCKS - 1;
 
   private final ExecutorService _executorService;
   private final ConcurrentHashMap<OpChainId, Future<?>> _submittedOpChainMap = 
new ConcurrentHashMap<>();
   private final Cache<OpChainId, MultiStageOperator> _opChainCache;
-
+  private final ReadWriteLock[] _queryLocks;
+  private final Cache<Long, Boolean> _cancelledQueryCache;
 
   public OpChainSchedulerService(ExecutorService executorService, 
PinotConfiguration config) {
-    this(
-        executorService,
-        config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_SIZE,
+    this(executorService, 
config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_SIZE,
             MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_SIZE),
         
config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_EXPIRE_MS,
-            MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS)
-    );
+            MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS),
+        
config.getProperty(MultiStageQueryRunner.KEY_OF_CANCELLED_QUERY_CACHE_SIZE,
+            MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_SIZE),
+        
config.getProperty(MultiStageQueryRunner.KEY_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS,
+            MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS));
   }
 
   public OpChainSchedulerService(ExecutorService executorService) {
     this(executorService, MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_SIZE,
-        MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS);
+        MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS,
+        MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_SIZE,
+        MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS);
   }
 
-  public OpChainSchedulerService(ExecutorService executorService, int 
maxWeight, long expireAfterWriteMs) {
+  public OpChainSchedulerService(ExecutorService executorService, int 
opStatsCacheSize, long opStatsCacheExpireMs,
+      int cancelledQueryCacheSize, long cancelledQueryCacheExpireMs) {
     _executorService = executorService;
     _opChainCache = CacheBuilder.newBuilder()
         .weigher((OpChainId key, MultiStageOperator value) -> 
countOperators(value))
-        .maximumWeight(maxWeight)
-        .expireAfterWrite(expireAfterWriteMs, TimeUnit.MILLISECONDS)
+        .maximumWeight(opStatsCacheSize)
+        .expireAfterWrite(opStatsCacheExpireMs, TimeUnit.MILLISECONDS)
+        .build();
+    _queryLocks = new ReadWriteLock[NUM_QUERY_LOCKS];
+    for (int i = 0; i < NUM_QUERY_LOCKS; i++) {
+      _queryLocks[i] = new ReentrantReadWriteLock();
+    }
+    _cancelledQueryCache = CacheBuilder.newBuilder()
+        .maximumSize(cancelledQueryCacheSize)
+        .expireAfterWrite(cancelledQueryCacheExpireMs, TimeUnit.MILLISECONDS)
         .build();
   }
 
   public void register(OpChain operatorChain) {
+    // Acquire read lock for the query to ensure that the query is not 
cancelled while scheduling the operator chain.
+    long requestId = operatorChain.getId().getRequestId();
+    Lock readLock = getQueryLock(requestId).readLock();
+    readLock.lock();
+    try {
+      // Do not schedule the operator chain if the query has been cancelled.
+      if (_cancelledQueryCache.getIfPresent(requestId) != null) {
+        LOGGER.debug("({}): Query has been cancelled", operatorChain);
+        throw new QueryCancelledException(
+            "Query has been cancelled before op-chain: " + 
operatorChain.getId() + " being scheduled");
+      } else {
+        registerInternal(operatorChain);
+      }
+    } finally {
+      readLock.unlock();
+    }
+  }
+
+  private void registerInternal(OpChain operatorChain) {
+    OpChainId opChainId = operatorChain.getId();
     Future<?> scheduledFuture = _executorService.submit(new TraceRunnable() {
       @Override
       public void runJob() {
@@ -83,8 +124,8 @@ public class OpChainSchedulerService {
         // try-with-resources to ensure that the operator chain is closed
         // TODO: Change the code so we ownership is expressed in the code in a 
better way
         try (OpChain closeMe = operatorChain) {
-          
Tracing.ThreadAccountantOps.setupWorker(operatorChain.getId().getStageId(),
-              ThreadExecutionContext.TaskType.MSE, 
operatorChain.getParentContext());
+          Tracing.ThreadAccountantOps.setupWorker(opChainId.getStageId(), 
ThreadExecutionContext.TaskType.MSE,
+              operatorChain.getParentContext());
           LOGGER.trace("({}): Executing", operatorChain);
           MseBlock result = operatorChain.getRoot().nextBlock();
           while (result.isData()) {
@@ -96,13 +137,13 @@ public class OpChainSchedulerService {
             LOGGER.error("({}): Completed erroneously {} {}", operatorChain, 
stats, errorBlock.getErrorMessages());
           } else {
             LOGGER.debug("({}): Completed {}", operatorChain, stats);
-            _opChainCache.invalidate(operatorChain.getId());
+            _opChainCache.invalidate(opChainId);
           }
         } catch (Exception e) {
           LOGGER.error("({}): Failed to execute operator chain!", 
operatorChain, e);
           thrown = e;
         } finally {
-          _submittedOpChainMap.remove(operatorChain.getId());
+          _submittedOpChainMap.remove(opChainId);
           if (errorBlock != null || thrown != null) {
             if (thrown == null) {
               thrown = new RuntimeException("Error block " + 
errorBlock.getErrorMessages());
@@ -113,11 +154,20 @@ public class OpChainSchedulerService {
         }
       }
     });
-    _opChainCache.put(operatorChain.getId(), operatorChain.getRoot());
-    _submittedOpChainMap.put(operatorChain.getId(), scheduledFuture);
+    _opChainCache.put(opChainId, operatorChain.getRoot());
+    _submittedOpChainMap.put(opChainId, scheduledFuture);
   }
 
   public Map<Integer, MultiStageQueryStats.StageStats.Closed> cancel(long 
requestId) {
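+    // Acquire write lock for the query so that marking it as cancelled cannot 
interleave with register(): op-chains registered before this point are already 
in _submittedOpChainMap (cancelled below), and later register() calls see the marker and are rejected.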
+    Lock writeLock = getQueryLock(requestId).writeLock();
+    writeLock.lock();
+    try {
+      _cancelledQueryCache.put(requestId, Boolean.TRUE);
+    } finally {
+      writeLock.unlock();
+    }
+
     // simple cancellation. for leaf stage this cannot be a dangling opchain 
b/c they will eventually be cleared up
     // via query timeout.
     Iterator<Map.Entry<OpChainId, Future<?>>> iterator = 
_submittedOpChainMap.entrySet().iterator();
@@ -166,4 +216,8 @@ public class OpChainSchedulerService {
     }
     return result;
   }
+
+  private ReadWriteLock getQueryLock(long requestId) {
+    return _queryLocks[(int) (requestId & QUERY_LOCK_MASK)];
+  }
 }
diff --git 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
index 27f375e4dbc..2f997918550 100644
--- 
a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
+++ 
b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
@@ -18,7 +18,6 @@
  */
 package org.apache.pinot.query.runtime.plan.pipeline;
 
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
@@ -38,6 +37,7 @@ import org.apache.pinot.query.runtime.operator.OpChain;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.query.runtime.plan.PlanNodeToOpChain;
 import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.exception.QueryCancelledException;
 import org.apache.pinot.spi.query.QueryThreadContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -107,9 +107,14 @@ public class PipelineBreakerExecutor {
                 stagePlan.getStageMetadata(), workerMetadata, null, 
parentContext, sendStats);
         return execute(scheduler, pipelineBreakerContext, 
opChainExecutionContext);
       } catch (Exception e) {
-        LOGGER.error("Caught exception executing pipeline breaker for request: 
{}, stage: {}", requestId,
-            stagePlan.getStageMetadata().getStageId(), e);
-        return new 
PipelineBreakerResult(pipelineBreakerContext.getNodeIdMap(), 
Collections.emptyMap(),
+        if (e instanceof QueryCancelledException) {
+          LOGGER.debug("Pipeline breaker execution cancelled for request: {}, 
stage: {}", requestId,
+              stagePlan.getStageMetadata().getStageId(), e);
+        } else {
+          LOGGER.error("Caught exception executing pipeline breaker for 
request: {}, stage: {}", requestId,
+              stagePlan.getStageMetadata().getStageId(), e);
+        }
+        return new 
PipelineBreakerResult(pipelineBreakerContext.getNodeIdMap(), Map.of(),
             ErrorMseBlock.fromException(e), null);
       }
     } else {
diff --git 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
index 40c676313ae..ac156ec2d83 100644
--- 
a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
+++ 
b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
@@ -34,10 +34,10 @@ import 
org.apache.pinot.query.runtime.operator.MultiStageOperator;
 import org.apache.pinot.query.runtime.operator.OpChain;
 import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.exception.QueryCancelledException;
 import org.apache.pinot.spi.executor.ExecutorServiceUtils;
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
-import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
 import org.testng.annotations.BeforeMethod;
@@ -46,6 +46,8 @@ import org.testng.annotations.Test;
 import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
 
 
 public class OpChainSchedulerServiceTest {
@@ -81,8 +83,7 @@ public class OpChainSchedulerServiceTest {
     WorkerMetadata workerMetadata = new WorkerMetadata(0, Map.of(), Map.of());
     OpChainExecutionContext context =
         new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE, 
Long.MAX_VALUE, Map.of(),
-            new StageMetadata(0, ImmutableList.of(workerMetadata), Map.of()), 
workerMetadata, null, null,
-            true);
+            new StageMetadata(0, ImmutableList.of(workerMetadata), Map.of()), 
workerMetadata, null, null, true);
     return new OpChain(context, operator);
   }
 
@@ -100,7 +101,7 @@ public class OpChainSchedulerServiceTest {
 
     schedulerService.register(opChain);
 
-    Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be 
called in less than 10 seconds");
+    assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called 
in less than 10 seconds");
   }
 
   @Test
@@ -117,7 +118,7 @@ public class OpChainSchedulerServiceTest {
 
     schedulerService.register(opChain);
 
-    Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be 
called in less than 10 seconds");
+    assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called 
in less than 10 seconds");
   }
 
   @Test
@@ -135,7 +136,7 @@ public class OpChainSchedulerServiceTest {
 
     schedulerService.register(opChain);
 
-    Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be 
called in less than 10 seconds");
+    assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called 
in less than 10 seconds");
   }
 
   @Test
@@ -153,7 +154,7 @@ public class OpChainSchedulerServiceTest {
 
     schedulerService.register(opChain);
 
-    Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be 
called in less than 10 seconds");
+    assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called 
in less than 10 seconds");
   }
 
   @Test
@@ -179,12 +180,12 @@ public class OpChainSchedulerServiceTest {
 
     schedulerService.register(opChain);
 
-    Assert.assertTrue(opChainStarted.await(10, TimeUnit.SECONDS), "op chain 
doesn't seem to be started");
+    assertTrue(opChainStarted.await(10, TimeUnit.SECONDS), "op chain doesn't 
seem to be started");
 
     // now cancel the request.
-    schedulerService.cancel(123);
+    schedulerService.cancel(123L);
 
-    Assert.assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected 
OpChain to be cancelled");
+    assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected OpChain to 
be cancelled");
     Mockito.verify(_operatorA, Mockito.times(1)).cancel(Mockito.any());
   }
 
@@ -203,7 +204,82 @@ public class OpChainSchedulerServiceTest {
 
     schedulerService.register(opChain);
 
-    Assert.assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected 
OpChain to be cancelled");
+    assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected OpChain to 
be cancelled");
     Mockito.verify(_operatorA, Mockito.times(1)).cancel(Mockito.any());
   }
+
+  @Test
+  public void 
shouldThrowQueryCancelledExceptionWhenRegisteringOpChainAfterQueryCancellation()
 {
+    OpChain opChain = getChain(_operatorA);
+    OpChainSchedulerService schedulerService = new 
OpChainSchedulerService(_executor);
+
+    // First cancel the query with the same requestId (123L as defined in 
getChain method)
+    schedulerService.cancel(123L);
+
+    // Now try to register an OpChain for the same query - should throw 
QueryCancelledException
+    try {
+      schedulerService.register(opChain);
+      fail("Expected QueryCancelledException to be thrown when registering 
OpChain after query cancellation");
+    } catch (QueryCancelledException e) {
+      assertTrue(e.getMessage().contains("Query has been cancelled before 
op-chain"));
+      assertTrue(e.getMessage().contains("being scheduled"));
+    }
+  }
+
+  @Test
+  public void shouldHandleConcurrentCancellationAndRegistration()
+      throws InterruptedException {
+    OpChain opChain = getChain(_operatorA);
+    OpChainSchedulerService schedulerService = new 
OpChainSchedulerService(_executor);
+
+    // Setup mock to slow down the execution so we can test concurrent 
cancellation
+    CountDownLatch registrationStarted = new CountDownLatch(1);
+    CountDownLatch cancellationCanProceed = new CountDownLatch(1);
+    Mockito.when(_operatorA.nextBlock()).thenAnswer(inv -> {
+      registrationStarted.countDown();
+      cancellationCanProceed.await();
+      return SuccessMseBlock.INSTANCE;
+    });
+    Mockito.doAnswer(inv -> 
MultiStageQueryStats.emptyStats(1)).when(_operatorA).calculateStats();
+
+    // Start registration in a separate thread
+    CountDownLatch registrationCompleted = new CountDownLatch(1);
+    boolean[] registrationSucceeded = {false};
+    Thread registrationThread = new Thread(() -> {
+      try {
+        schedulerService.register(opChain);
+        registrationSucceeded[0] = true;
+      } catch (QueryCancelledException e) {
+        // Expected if cancellation happens before registration
+        registrationSucceeded[0] = false;
+      } finally {
+        registrationCompleted.countDown();
+      }
+    });
+
+    registrationThread.start();
+
+    // Wait for the registered op-chain to start executing (nextBlock() counts down the latch)
+    assertTrue(registrationStarted.await(10, TimeUnit.SECONDS), "Registration 
should have started");
+
+    // Now cancel the query while it's running
+    schedulerService.cancel(123L);
+
+    // Allow the opchain execution to complete
+    cancellationCanProceed.countDown();
+
+    // Wait for registration thread to complete
+    assertTrue(registrationCompleted.await(10, TimeUnit.SECONDS), 
"Registration thread should complete");
+
+    // The registration should have succeeded since it started before 
cancellation
+    assertTrue(registrationSucceeded[0], "Registration should have succeeded");
+
+    // Now try to register another OpChain with the same requestId - should 
fail
+    try {
+      schedulerService.register(getChain(_operatorA));
+      fail("Expected QueryCancelledException for subsequent registration after 
cancellation");
+    } catch (QueryCancelledException e) {
+      assertTrue(e.getMessage().contains("Query has been cancelled before 
op-chain"));
+    }
+  }
 }
diff --git 
a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java 
b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index f725dfde334..cdf5f378f7e 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -1953,7 +1953,17 @@ public class CommonConstants {
     /// Max time to keep the op stats in the cache.
     public static final String KEY_OF_OP_STATS_CACHE_EXPIRE_MS = 
"pinot.server.query.op.stats.cache.ms";
     public static final int DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS = 60 * 1000;
+
+    /// Max number of cancelled queries to keep in the cache.
+    public static final String KEY_OF_CANCELLED_QUERY_CACHE_SIZE = 
"pinot.server.query.cancelled.cache.size";
+    public static final int DEFAULT_OF_CANCELLED_QUERY_CACHE_SIZE = 1000;
+
+    /// Max time to keep the cancelled queries in the cache.
+    public static final String KEY_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS = 
"pinot.server.query.cancelled.cache.ms";
+    public static final int DEFAULT_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS = 60 * 
1000;
+
     /// Timeout of the cancel request, in milliseconds.
+    /// TODO: This is used by the broker. Consider renaming it.
     public static final String KEY_OF_CANCEL_TIMEOUT_MS = 
"pinot.server.query.cancel.timeout.ms";
     public static final long DEFAULT_OF_CANCEL_TIMEOUT_MS = 1000;
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to