This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 24aef1df4e4 Fix query cancellation on op-chain not scheduled (#16625)
24aef1df4e4 is described below
commit 24aef1df4e4d8caa3146516cebc43d00fb8ae7a3
Author: Xiaotian (Jackie) Jiang <[email protected]>
AuthorDate: Wed Aug 20 17:35:34 2025 -0600
Fix query cancellation on op-chain not scheduled (#16625)
---
.../runtime/executor/OpChainSchedulerService.java | 86 +++++++++++++++----
.../plan/pipeline/PipelineBreakerExecutor.java | 13 ++-
.../executor/OpChainSchedulerServiceTest.java | 98 +++++++++++++++++++---
.../apache/pinot/spi/utils/CommonConstants.java | 10 +++
4 files changed, 176 insertions(+), 31 deletions(-)
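The change closes a race in the multi-stage engine: cancel(requestId) could run before an op-chain for that query was ever scheduled, so a late-arriving register() would submit work that nothing would ever cancel. The patch records cancelled request ids in a bounded Guava cache and guards register()/cancel() with striped ReentrantReadWriteLocks keyed by request id: register() takes the read lock and rejects already-cancelled queries with a QueryCancelledException, while cancel() takes the write lock before marking the query. Below is a minimal, self-contained sketch of that guard pattern, not the Pinot classes themselves; the QueryGuard name, the plain ConcurrentHashMap standing in for the Guava cache, and the IllegalStateException are illustrative stand-ins.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Illustrative stand-in for the scheduler's new cancellation guard; not Pinot code.
final class QueryGuard {
  private static final int NUM_LOCKS = 1 << 10; // power of two so a mask can replace modulo
  private static final int LOCK_MASK = NUM_LOCKS - 1;

  private final ReadWriteLock[] _locks = new ReadWriteLock[NUM_LOCKS];
  // The real patch uses a size- and time-bounded Guava cache; a plain map keeps the sketch short.
  private final Map<Long, Boolean> _cancelled = new ConcurrentHashMap<>();

  QueryGuard() {
    for (int i = 0; i < NUM_LOCKS; i++) {
      _locks[i] = new ReentrantReadWriteLock();
    }
  }

  // register() path: refuse to schedule work for a query that is already cancelled.
  void runIfNotCancelled(long requestId, Runnable scheduleOpChain) {
    Lock readLock = _locks[(int) (requestId & LOCK_MASK)].readLock();
    readLock.lock();
    try {
      if (_cancelled.containsKey(requestId)) {
        throw new IllegalStateException("Query " + requestId + " has already been cancelled");
      }
      scheduleOpChain.run();
    } finally {
      readLock.unlock();
    }
  }

  // cancel() path: mark the query first, then cancel whatever was already scheduled.
  void markCancelled(long requestId) {
    Lock writeLock = _locks[(int) (requestId & LOCK_MASK)].writeLock();
    writeLock.lock();
    try {
      _cancelled.put(requestId, Boolean.TRUE);
    } finally {
      writeLock.unlock();
    }
  }
}

Because cancel() takes the write lock before recording the request id, a register() already holding the read lock finishes scheduling first and its Future stays visible to the subsequent cancellation sweep, while any register() that starts afterwards sees the marker and fails fast instead of leaving a dangling op-chain.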
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
index 18d34f35c67..807e0d038a5 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerService.java
@@ -27,6 +27,9 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;
import org.apache.pinot.core.util.trace.TraceRunnable;
import org.apache.pinot.query.runtime.blocks.ErrorMseBlock;
@@ -37,44 +40,82 @@ import org.apache.pinot.query.runtime.operator.OpChainId;
import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
import org.apache.pinot.spi.accounting.ThreadExecutionContext;
import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.exception.QueryCancelledException;
import org.apache.pinot.spi.trace.Tracing;
import org.apache.pinot.spi.utils.CommonConstants.MultiStageQueryRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
+/// TODO: Use CID to manage the queries
public class OpChainSchedulerService {
private static final Logger LOGGER = LoggerFactory.getLogger(OpChainSchedulerService.class);
+ private static final int NUM_QUERY_LOCKS = 1 << 10; // 1024 locks
+ private static final int QUERY_LOCK_MASK = NUM_QUERY_LOCKS - 1;
private final ExecutorService _executorService;
private final ConcurrentHashMap<OpChainId, Future<?>> _submittedOpChainMap = new ConcurrentHashMap<>();
private final Cache<OpChainId, MultiStageOperator> _opChainCache;
-
+ private final ReadWriteLock[] _queryLocks;
+ private final Cache<Long, Boolean> _cancelledQueryCache;
public OpChainSchedulerService(ExecutorService executorService, PinotConfiguration config) {
- this(
- executorService,
- config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_SIZE,
+ this(executorService,
config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_SIZE,
MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_SIZE),
config.getProperty(MultiStageQueryRunner.KEY_OF_OP_STATS_CACHE_EXPIRE_MS,
- MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS)
- );
+ MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS),
+ config.getProperty(MultiStageQueryRunner.KEY_OF_CANCELLED_QUERY_CACHE_SIZE,
+ MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_SIZE),
+ config.getProperty(MultiStageQueryRunner.KEY_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS,
+ MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS));
}
public OpChainSchedulerService(ExecutorService executorService) {
this(executorService, MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_SIZE,
- MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS);
+ MultiStageQueryRunner.DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS,
+ MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_SIZE,
+ MultiStageQueryRunner.DEFAULT_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS);
}
- public OpChainSchedulerService(ExecutorService executorService, int maxWeight, long expireAfterWriteMs) {
+ public OpChainSchedulerService(ExecutorService executorService, int opStatsCacheSize, long opStatsCacheExpireMs,
+ int cancelledQueryCacheSize, long cancelledQueryCacheExpireMs) {
_executorService = executorService;
_opChainCache = CacheBuilder.newBuilder()
.weigher((OpChainId key, MultiStageOperator value) -> countOperators(value))
- .maximumWeight(maxWeight)
- .expireAfterWrite(expireAfterWriteMs, TimeUnit.MILLISECONDS)
+ .maximumWeight(opStatsCacheSize)
+ .expireAfterWrite(opStatsCacheExpireMs, TimeUnit.MILLISECONDS)
+ .build();
+ _queryLocks = new ReadWriteLock[NUM_QUERY_LOCKS];
+ for (int i = 0; i < NUM_QUERY_LOCKS; i++) {
+ _queryLocks[i] = new ReentrantReadWriteLock();
+ }
+ _cancelledQueryCache = CacheBuilder.newBuilder()
+ .maximumSize(cancelledQueryCacheSize)
+ .expireAfterWrite(cancelledQueryCacheExpireMs, TimeUnit.MILLISECONDS)
.build();
}
public void register(OpChain operatorChain) {
+ // Acquire read lock for the query to ensure that the query is not cancelled while scheduling the operator chain.
+ long requestId = operatorChain.getId().getRequestId();
+ Lock readLock = getQueryLock(requestId).readLock();
+ readLock.lock();
+ try {
+ // Do not schedule the operator chain if the query has been cancelled.
+ if (_cancelledQueryCache.getIfPresent(requestId) != null) {
+ LOGGER.debug("({}): Query has been cancelled", operatorChain);
+ throw new QueryCancelledException(
+ "Query has been cancelled before op-chain: " +
operatorChain.getId() + " being scheduled");
+ } else {
+ registerInternal(operatorChain);
+ }
+ } finally {
+ readLock.unlock();
+ }
+ }
+
+ private void registerInternal(OpChain operatorChain) {
+ OpChainId opChainId = operatorChain.getId();
Future<?> scheduledFuture = _executorService.submit(new TraceRunnable() {
@Override
public void runJob() {
@@ -83,8 +124,8 @@ public class OpChainSchedulerService {
// try-with-resources to ensure that the operator chain is closed
// TODO: Change the code so we ownership is expressed in the code in a better way
try (OpChain closeMe = operatorChain) {
- Tracing.ThreadAccountantOps.setupWorker(operatorChain.getId().getStageId(),
- ThreadExecutionContext.TaskType.MSE, operatorChain.getParentContext());
+ Tracing.ThreadAccountantOps.setupWorker(opChainId.getStageId(), ThreadExecutionContext.TaskType.MSE,
+ operatorChain.getParentContext());
LOGGER.trace("({}): Executing", operatorChain);
MseBlock result = operatorChain.getRoot().nextBlock();
while (result.isData()) {
@@ -96,13 +137,13 @@ public class OpChainSchedulerService {
LOGGER.error("({}): Completed erroneously {} {}", operatorChain,
stats, errorBlock.getErrorMessages());
} else {
LOGGER.debug("({}): Completed {}", operatorChain, stats);
- _opChainCache.invalidate(operatorChain.getId());
+ _opChainCache.invalidate(opChainId);
}
} catch (Exception e) {
LOGGER.error("({}): Failed to execute operator chain!",
operatorChain, e);
thrown = e;
} finally {
- _submittedOpChainMap.remove(operatorChain.getId());
+ _submittedOpChainMap.remove(opChainId);
if (errorBlock != null || thrown != null) {
if (thrown == null) {
thrown = new RuntimeException("Error block " + errorBlock.getErrorMessages());
@@ -113,11 +154,20 @@ public class OpChainSchedulerService {
}
}
});
- _opChainCache.put(operatorChain.getId(), operatorChain.getRoot());
- _submittedOpChainMap.put(operatorChain.getId(), scheduledFuture);
+ _opChainCache.put(opChainId, operatorChain.getRoot());
+ _submittedOpChainMap.put(opChainId, scheduledFuture);
}
public Map<Integer, MultiStageQueryStats.StageStats.Closed> cancel(long requestId) {
+ // Acquire write lock for the query to ensure that the query is not cancelled while scheduling the operator chain.
+ Lock writeLock = getQueryLock(requestId).writeLock();
+ writeLock.lock();
+ try {
+ _cancelledQueryCache.put(requestId, Boolean.TRUE);
+ } finally {
+ writeLock.unlock();
+ }
+
// simple cancellation. for leaf stage this cannot be a dangling opchain b/c they will eventually be cleared up
// via query timeout.
Iterator<Map.Entry<OpChainId, Future<?>>> iterator = _submittedOpChainMap.entrySet().iterator();
@@ -166,4 +216,8 @@ public class OpChainSchedulerService {
}
return result;
}
+
+ private ReadWriteLock getQueryLock(long requestId) {
+ return _queryLocks[(int) (requestId & QUERY_LOCK_MASK)];
+ }
}
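A note on the striping above: NUM_QUERY_LOCKS is a power of two, so getQueryLock() can pick a stripe with a bit mask instead of a modulo, and the result is always non-negative even for arbitrary long ids. A tiny illustrative check (hypothetical demo class, not part of the patch):

// Hypothetical demo, not part of the patch: a power-of-two stripe count lets a mask replace modulo.
public class LockStripeDemo {
  public static void main(String[] args) {
    int numLocks = 1 << 10;   // 1024 stripes, same as NUM_QUERY_LOCKS
    int mask = numLocks - 1;  // 0x3FF, same as QUERY_LOCK_MASK
    long requestId = 123L;
    int byMask = (int) (requestId & mask);
    int byModulo = (int) (requestId % numLocks);
    // For non-negative request ids the two agree; the mask can never go negative.
    System.out.println(byMask + " == " + byModulo); // prints "123 == 123"
  }
}

Striping by request id means distinct queries rarely contend on the same lock, while all op-chains of one query serialize against that query's cancellation.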
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
index 27f375e4dbc..2f997918550 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/pipeline/PipelineBreakerExecutor.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.query.runtime.plan.pipeline;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
@@ -38,6 +37,7 @@ import org.apache.pinot.query.runtime.operator.OpChain;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.query.runtime.plan.PlanNodeToOpChain;
import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.exception.QueryCancelledException;
import org.apache.pinot.spi.query.QueryThreadContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -107,9 +107,14 @@ public class PipelineBreakerExecutor {
stagePlan.getStageMetadata(), workerMetadata, null, parentContext, sendStats);
return execute(scheduler, pipelineBreakerContext, opChainExecutionContext);
} catch (Exception e) {
- LOGGER.error("Caught exception executing pipeline breaker for request:
{}, stage: {}", requestId,
- stagePlan.getStageMetadata().getStageId(), e);
- return new
PipelineBreakerResult(pipelineBreakerContext.getNodeIdMap(),
Collections.emptyMap(),
+ if (e instanceof QueryCancelledException) {
+ LOGGER.debug("Pipeline breaker execution cancelled for request: {},
stage: {}", requestId,
+ stagePlan.getStageMetadata().getStageId(), e);
+ } else {
+ LOGGER.error("Caught exception executing pipeline breaker for
request: {}, stage: {}", requestId,
+ stagePlan.getStageMetadata().getStageId(), e);
+ }
+ return new PipelineBreakerResult(pipelineBreakerContext.getNodeIdMap(), Map.of(),
ErrorMseBlock.fromException(e), null);
}
} else {
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
index 40c676313ae..ac156ec2d83 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/executor/OpChainSchedulerServiceTest.java
@@ -34,10 +34,10 @@ import org.apache.pinot.query.runtime.operator.MultiStageOperator;
import org.apache.pinot.query.runtime.operator.OpChain;
import org.apache.pinot.query.runtime.plan.MultiStageQueryStats;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.exception.QueryCancelledException;
import org.apache.pinot.spi.executor.ExecutorServiceUtils;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
-import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.BeforeMethod;
@@ -46,6 +46,8 @@ import org.testng.annotations.Test;
import static org.mockito.Mockito.clearInvocations;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
public class OpChainSchedulerServiceTest {
@@ -81,8 +83,7 @@ public class OpChainSchedulerServiceTest {
WorkerMetadata workerMetadata = new WorkerMetadata(0, Map.of(), Map.of());
OpChainExecutionContext context =
new OpChainExecutionContext(mailboxService, 123L, Long.MAX_VALUE, Long.MAX_VALUE, Map.of(),
- new StageMetadata(0, ImmutableList.of(workerMetadata), Map.of()), workerMetadata, null, null,
- true);
+ new StageMetadata(0, ImmutableList.of(workerMetadata), Map.of()), workerMetadata, null, null, true);
return new OpChain(context, operator);
}
@@ -100,7 +101,7 @@ public class OpChainSchedulerServiceTest {
schedulerService.register(opChain);
- Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
+ assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
}
@Test
@@ -117,7 +118,7 @@ public class OpChainSchedulerServiceTest {
schedulerService.register(opChain);
- Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
+ assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
}
@Test
@@ -135,7 +136,7 @@ public class OpChainSchedulerServiceTest {
schedulerService.register(opChain);
- Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
+ assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
}
@Test
@@ -153,7 +154,7 @@ public class OpChainSchedulerServiceTest {
schedulerService.register(opChain);
- Assert.assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
+ assertTrue(latch.await(10, TimeUnit.SECONDS), "expected await to be called in less than 10 seconds");
}
@Test
@@ -179,12 +180,12 @@ public class OpChainSchedulerServiceTest {
schedulerService.register(opChain);
- Assert.assertTrue(opChainStarted.await(10, TimeUnit.SECONDS), "op chain doesn't seem to be started");
+ assertTrue(opChainStarted.await(10, TimeUnit.SECONDS), "op chain doesn't seem to be started");
// now cancel the request.
- schedulerService.cancel(123);
+ schedulerService.cancel(123L);
- Assert.assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected OpChain to be cancelled");
+ assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected OpChain to be cancelled");
Mockito.verify(_operatorA, Mockito.times(1)).cancel(Mockito.any());
}
@@ -203,7 +204,82 @@ public class OpChainSchedulerServiceTest {
schedulerService.register(opChain);
- Assert.assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected OpChain to be cancelled");
+ assertTrue(cancelLatch.await(10, TimeUnit.SECONDS), "expected OpChain to be cancelled");
Mockito.verify(_operatorA, Mockito.times(1)).cancel(Mockito.any());
}
+
+ @Test
+ public void shouldThrowQueryCancelledExceptionWhenRegisteringOpChainAfterQueryCancellation() {
+ OpChain opChain = getChain(_operatorA);
+ OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);
+
+ // First cancel the query with the same requestId (123L as defined in getChain method)
+ schedulerService.cancel(123L);
+
+ // Now try to register an OpChain for the same query - should throw QueryCancelledException
+ try {
+ schedulerService.register(opChain);
+ fail("Expected QueryCancelledException to be thrown when registering
OpChain after query cancellation");
+ } catch (QueryCancelledException e) {
+ assertTrue(e.getMessage().contains("Query has been cancelled before op-chain"));
+ assertTrue(e.getMessage().contains("being scheduled"));
+ }
+ }
+
+ @Test
+ public void shouldHandleConcurrentCancellationAndRegistration()
+ throws InterruptedException {
+ OpChain opChain = getChain(_operatorA);
+ OpChainSchedulerService schedulerService = new OpChainSchedulerService(_executor);
+
+ // Setup mock to slow down the execution so we can test concurrent cancellation
+ CountDownLatch registrationStarted = new CountDownLatch(1);
+ CountDownLatch cancellationCanProceed = new CountDownLatch(1);
+ Mockito.when(_operatorA.nextBlock()).thenAnswer(inv -> {
+ registrationStarted.countDown();
+ cancellationCanProceed.await();
+ return SuccessMseBlock.INSTANCE;
+ });
+ Mockito.doAnswer(inv -> MultiStageQueryStats.emptyStats(1)).when(_operatorA).calculateStats();
+
+ // Start registration in a separate thread
+ CountDownLatch registrationCompleted = new CountDownLatch(1);
+ boolean[] registrationSucceeded = {false};
+ Thread registrationThread = new Thread(() -> {
+ try {
+ schedulerService.register(opChain);
+ registrationSucceeded[0] = true;
+ } catch (QueryCancelledException e) {
+ // Expected if cancellation happens before registration
+ registrationSucceeded[0] = false;
+ } finally {
+ registrationCompleted.countDown();
+ }
+ });
+
+ registrationThread.start();
+
+ // Wait for registration to start
+ assertTrue(registrationStarted.await(10, TimeUnit.SECONDS), "Registration should have started");
+
+ // Now cancel the query while it's running
+ schedulerService.cancel(123L);
+
+ // Allow the opchain execution to complete
+ cancellationCanProceed.countDown();
+
+ // Wait for registration thread to complete
+ assertTrue(registrationCompleted.await(10, TimeUnit.SECONDS), "Registration thread should complete");
+
+ // The registration should have succeeded since it started before cancellation
+ assertTrue(registrationSucceeded[0], "Registration should have succeeded");
+
+ // Now try to register another OpChain with the same requestId - should fail
+ try {
+ schedulerService.register(getChain(_operatorA));
+ fail("Expected QueryCancelledException for subsequent registration after
cancellation");
+ } catch (QueryCancelledException e) {
+ assertTrue(e.getMessage().contains("Query has been cancelled before op-chain"));
+ }
+ }
}
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index f725dfde334..cdf5f378f7e 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -1953,7 +1953,17 @@ public class CommonConstants {
/// Max time to keep the op stats in the cache.
public static final String KEY_OF_OP_STATS_CACHE_EXPIRE_MS = "pinot.server.query.op.stats.cache.ms";
public static final int DEFAULT_OF_OP_STATS_CACHE_EXPIRE_MS = 60 * 1000;
+
+ /// Max number of cancelled queries to keep in the cache.
+ public static final String KEY_OF_CANCELLED_QUERY_CACHE_SIZE = "pinot.server.query.cancelled.cache.size";
+ public static final int DEFAULT_OF_CANCELLED_QUERY_CACHE_SIZE = 1000;
+
+ /// Max time to keep the cancelled queries in the cache.
+ public static final String KEY_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS = "pinot.server.query.cancelled.cache.ms";
+ public static final int DEFAULT_OF_CANCELLED_QUERY_CACHE_EXPIRE_MS = 60 * 1000;
+
/// Timeout of the cancel request, in milliseconds.
+ /// TODO: This is used by the broker. Consider renaming it.
public static final String KEY_OF_CANCEL_TIMEOUT_MS = "pinot.server.query.cancel.timeout.ms";
public static final long DEFAULT_OF_CANCEL_TIMEOUT_MS = 1000;
}
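The two new server properties default to 1000 entries and 60 * 1000 ms. A minimal sketch of overriding them when constructing the scheduler, assuming PinotConfiguration's map-based constructor; the override values and the executor setup are illustrative, while the keys and the (ExecutorService, PinotConfiguration) constructor come from the diff above.

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.pinot.query.runtime.executor.OpChainSchedulerService;
import org.apache.pinot.spi.env.PinotConfiguration;

public class SchedulerConfigExample {
  public static void main(String[] args) {
    // Hypothetical overrides: remember up to 5000 cancelled queries for 2 minutes instead of the defaults.
    Map<String, Object> props = new HashMap<>();
    props.put("pinot.server.query.cancelled.cache.size", 5000);
    props.put("pinot.server.query.cancelled.cache.ms", 2 * 60 * 1000);
    PinotConfiguration config = new PinotConfiguration(props);

    ExecutorService executor = Executors.newCachedThreadPool();
    OpChainSchedulerService scheduler = new OpChainSchedulerService(executor, config);
    // ... register op-chains / cancel queries as usual ...
    executor.shutdownNow();
  }
}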