This is an automated email from the ASF dual-hosted git repository.

asdf2014 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 302739aa58c more aggressive cancellation of broker parallel merge, more chill blocking queue timeouts, and query cancellation participation (#16748)
302739aa58c is described below

commit 302739aa58c7a5c9fa7c16686d70de9c102d5dd9
Author: Clint Wylie <[email protected]>
AuthorDate: Tue Jul 23 23:58:34 2024 -0700

    more aggressive cancellation of broker parallel merge, more chill blocking queue timeouts, and query cancellation participation (#16748)
    
    * more aggressive cancellation of broker parallel merge, more chill blocking queue timeouts
    
    * wire parallel merge into query cancellation system
    
    * oops
    
    * style
    
    * adjust metrics initialization
    
    * fix timeout, fix cleanup to not block
    
    * javadocs to clarify why cancellation future and gizmo are split
    
    * cancelled -> canceled, simplify QueuePusher since it always takes a ResultBatch, replace the raw-typed static terminal marker with a generic non-static ResultBatch.terminal() factory to eliminate raw-type warnings, specialize tryOffer to be tryOfferTerminal so it won't be misused, add comments to clarify the reason for non-blocking offers that might fail
---
 .../guava/ParallelMergeCombiningSequence.java      | 290 ++++++++++++++-------
 .../guava/ParallelMergeCombiningSequenceTest.java  | 156 +++++++++--
 .../druid/client/CachingClusteredClient.java       |   4 +-
 3 files changed, 343 insertions(+), 107 deletions(-)

diff --git a/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java b/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java
index 517235a99f9..ca2708700f0 100644
--- a/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java
+++ b/processing/src/main/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequence.java
@@ -19,10 +19,9 @@
 
 package org.apache.druid.java.util.common.guava;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Ordering;
-import org.apache.druid.java.util.common.RE;
+import com.google.common.util.concurrent.AbstractFuture;
 import org.apache.druid.java.util.common.io.Closer;
 import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.query.QueryTimeoutException;
@@ -63,6 +62,7 @@ import java.util.function.Consumer;
 public class ParallelMergeCombiningSequence<T> extends YieldingSequenceBase<T>
 {
   private static final Logger LOG = new 
Logger(ParallelMergeCombiningSequence.class);
+  private static final long BLOCK_TIMEOUT = TimeUnit.NANOSECONDS.convert(500, 
TimeUnit.MILLISECONDS);
 
   // these values were chosen carefully via feedback from benchmarks,
   // see PR https://github.com/apache/druid/pull/8578 for details
@@ -84,7 +84,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
   private final long targetTimeNanos;
   private final Consumer<MergeCombineMetrics> metricsReporter;
 
-  private final CancellationGizmo cancellationGizmo;
+  private final CancellationFuture cancellationFuture;
 
   public ParallelMergeCombiningSequence(
       ForkJoinPool workerPool,
@@ -114,14 +114,24 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     this.targetTimeNanos = TimeUnit.NANOSECONDS.convert(targetTimeMillis, 
TimeUnit.MILLISECONDS);
     this.queueSize = (1 << 15) / batchSize; // each queue can by default hold 
~32k rows
     this.metricsReporter = reporter;
-    this.cancellationGizmo = new CancellationGizmo();
+    this.cancellationFuture = new CancellationFuture(new CancellationGizmo());
   }
 
   @Override
   public <OutType> Yielder<OutType> toYielder(OutType initValue, 
YieldingAccumulator<OutType, T> accumulator)
   {
     if (inputSequences.isEmpty()) {
-      return Sequences.<T>empty().toYielder(initValue, accumulator);
+      return Sequences.wrap(
+          Sequences.<T>empty(),
+          new SequenceWrapper()
+          {
+            @Override
+            public void after(boolean isDone, Throwable thrown)
+            {
+              cancellationFuture.set(true);
+            }
+          }
+      ).toYielder(initValue, accumulator);
     }
     // we make final output queue larger than the merging queues so if 
downstream readers are slower to read there is
     // less chance of blocking the merge
@@ -144,27 +154,43 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
         hasTimeout,
         timeoutAtNanos,
         metricsAccumulator,
-        cancellationGizmo
+        cancellationFuture.cancellationGizmo
     );
     workerPool.execute(mergeCombineAction);
-    Sequence<T> finalOutSequence = makeOutputSequenceForQueue(
-        outputQueue,
-        hasTimeout,
-        timeoutAtNanos,
-        cancellationGizmo
-    ).withBaggage(() -> {
-      if (metricsReporter != null) {
-        metricsAccumulator.setTotalWallTime(System.nanoTime() - 
startTimeNanos);
-        metricsReporter.accept(metricsAccumulator.build());
-      }
-    });
+
+    final Sequence<T> finalOutSequence = Sequences.wrap(
+        makeOutputSequenceForQueue(
+            outputQueue,
+            hasTimeout,
+            timeoutAtNanos,
+            cancellationFuture.cancellationGizmo
+        ),
+        new SequenceWrapper()
+        {
+          @Override
+          public void after(boolean isDone, Throwable thrown)
+          {
+            if (isDone) {
+              cancellationFuture.set(true);
+            } else {
+              cancellationFuture.cancel(true);
+            }
+            if (metricsReporter != null) {
+              metricsAccumulator.setTotalWallTime(System.nanoTime() - 
startTimeNanos);
+              metricsReporter.accept(metricsAccumulator.build());
+            }
+          }
+        }
+    );
     return finalOutSequence.toYielder(initValue, accumulator);
   }
 
-  @VisibleForTesting
-  public CancellationGizmo getCancellationGizmo()
+  /**
+   *
+   */
+  public CancellationFuture getCancellationFuture()
   {
-    return cancellationGizmo;
+    return cancellationFuture;
   }
 
   /**
@@ -181,8 +207,6 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     return new BaseSequence<>(
         new BaseSequence.IteratorMaker<T, Iterator<T>>()
         {
-          private boolean shouldCancelOnCleanup = true;
-
           @Override
           public Iterator<T> make()
           {
@@ -195,7 +219,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
               {
                 final long thisTimeoutNanos = timeoutAtNanos - 
System.nanoTime();
                 if (hasTimeout && thisTimeoutNanos < 0) {
-                  throw new QueryTimeoutException();
+                  throw cancellationGizmo.cancelAndThrow(new 
QueryTimeoutException());
                 }
 
                 if (currentBatch != null && !currentBatch.isTerminalResult() 
&& !currentBatch.isDrained()) {
@@ -210,33 +234,32 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
                     }
                   }
                   if (currentBatch == null) {
-                    throw new QueryTimeoutException();
+                    throw cancellationGizmo.cancelAndThrow(new 
QueryTimeoutException());
                   }
 
-                  if (cancellationGizmo.isCancelled()) {
+                  if (cancellationGizmo.isCanceled()) {
                     throw cancellationGizmo.getRuntimeException();
                   }
 
                   if (currentBatch.isTerminalResult()) {
-                    shouldCancelOnCleanup = false;
                     return false;
                   }
                   return true;
                 }
                 catch (InterruptedException e) {
-                  throw new RE(e);
+                  throw cancellationGizmo.cancelAndThrow(e);
                 }
               }
 
               @Override
               public T next()
               {
-                if (cancellationGizmo.isCancelled()) {
+                if (cancellationGizmo.isCanceled()) {
                   throw cancellationGizmo.getRuntimeException();
                 }
 
                 if (currentBatch == null || currentBatch.isDrained() || 
currentBatch.isTerminalResult()) {
-                  throw new NoSuchElementException();
+                  throw cancellationGizmo.cancelAndThrow(new 
NoSuchElementException());
                 }
                 return currentBatch.next();
               }
@@ -246,9 +269,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
           @Override
           public void cleanup(Iterator<T> iterFromMake)
           {
-            if (shouldCancelOnCleanup) {
-              cancellationGizmo.cancel(new RuntimeException("Already closed"));
-            }
+            // nothing to cleanup
           }
         }
     );
@@ -338,7 +359,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
               parallelTaskCount
           );
 
-          QueuePusher<ResultBatch<T>> resultsPusher = new QueuePusher<>(out, 
hasTimeout, timeoutAt);
+          QueuePusher<T> resultsPusher = new QueuePusher<>(out, 
cancellationGizmo, hasTimeout, timeoutAt);
 
           for (Sequence<T> s : sequences) {
             sequenceCursors.add(new YielderBatchedResultsCursor<>(new 
SequenceBatcher<>(s, batchSize), orderingFn));
@@ -367,10 +388,10 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
       catch (Throwable t) {
         closeAllCursors(sequenceCursors);
         cancellationGizmo.cancel(t);
-        // Should be the following, but can' change due to lack of
-        // unit tests.
-        // out.offer((ParallelMergeCombiningSequence.ResultBatch<T>) 
ResultBatch.TERMINAL);
-        out.offer(ResultBatch.TERMINAL);
+        // offer terminal result if queue is not full in case out is empty to 
allow downstream threads waiting on
+        // stuff to be present to stop blocking immediately. However, if the 
queue is full, it doesn't matter if we
+        // write anything because the cancellation signal has been set, which 
will also terminate processing.
+        out.offer(ResultBatch.terminal());
       }
     }
 
@@ -387,7 +408,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
       for (List<Sequence<T>> partition : partitions) {
         BlockingQueue<ResultBatch<T>> outputQueue = new 
ArrayBlockingQueue<>(queueSize);
         intermediaryOutputs.add(outputQueue);
-        QueuePusher<ResultBatch<T>> pusher = new QueuePusher<>(outputQueue, 
hasTimeout, timeoutAt);
+        QueuePusher<T> pusher = new QueuePusher<>(outputQueue, 
cancellationGizmo, hasTimeout, timeoutAt);
 
         List<BatchedResultsCursor<T>> partitionCursors = new 
ArrayList<>(sequences.size());
         for (Sequence<T> s : partition) {
@@ -415,11 +436,11 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
         getPool().execute(task);
       }
 
-      QueuePusher<ResultBatch<T>> outputPusher = new QueuePusher<>(out, 
hasTimeout, timeoutAt);
+      QueuePusher<T> outputPusher = new QueuePusher<>(out, cancellationGizmo, 
hasTimeout, timeoutAt);
       List<BatchedResultsCursor<T>> intermediaryOutputsCursors = new 
ArrayList<>(intermediaryOutputs.size());
       for (BlockingQueue<ResultBatch<T>> queue : intermediaryOutputs) {
         intermediaryOutputsCursors.add(
-            new BlockingQueueuBatchedResultsCursor<>(queue, orderingFn, 
hasTimeout, timeoutAt)
+            new BlockingQueueuBatchedResultsCursor<>(queue, cancellationGizmo, 
orderingFn, hasTimeout, timeoutAt)
         );
       }
       MergeCombineActionMetricsAccumulator finalMergeMetrics = new 
MergeCombineActionMetricsAccumulator();
@@ -513,7 +534,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     private final PriorityQueue<BatchedResultsCursor<T>> pQueue;
     private final Ordering<T> orderingFn;
     private final BinaryOperator<T> combineFn;
-    private final QueuePusher<ResultBatch<T>> outputQueue;
+    private final QueuePusher<T> outputQueue;
     private final T initialValue;
     private final int yieldAfter;
     private final int batchSize;
@@ -523,7 +544,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
 
     private MergeCombineAction(
         PriorityQueue<BatchedResultsCursor<T>> pQueue,
-        QueuePusher<ResultBatch<T>> outputQueue,
+        QueuePusher<T> outputQueue,
         Ordering<T> orderingFn,
         BinaryOperator<T> combineFn,
         T initialValue,
@@ -550,6 +571,10 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     @Override
     protected void compute()
     {
+      if (cancellationGizmo.isCanceled()) {
+        cleanup();
+        return;
+      }
       try {
         long start = System.nanoTime();
         long startCpuNanos = JvmUtils.safeGetThreadCpuTime();
@@ -608,7 +633,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
         metricsAccumulator.incrementCpuTimeNanos(elapsedCpuNanos);
         metricsAccumulator.incrementTaskCount();
 
-        if (!pQueue.isEmpty() && !cancellationGizmo.isCancelled()) {
+        if (!pQueue.isEmpty() && !cancellationGizmo.isCanceled()) {
           // if there is still work to be done, execute a new task with the 
current accumulated value to continue
           // combining where we left off
           if (!outputBatch.isDrained()) {
@@ -650,29 +675,36 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
               metricsAccumulator,
               cancellationGizmo
           ));
-        } else if (cancellationGizmo.isCancelled()) {
+        } else if (cancellationGizmo.isCanceled()) {
           // if we got the cancellation signal, go ahead and write terminal 
value into output queue to help gracefully
           // allow downstream stuff to stop
-          LOG.debug("cancelled after %s tasks", 
metricsAccumulator.getTaskCount());
+          LOG.debug("canceled after %s tasks", 
metricsAccumulator.getTaskCount());
           // make sure to close underlying cursors
-          closeAllCursors(pQueue);
-          outputQueue.offer(ResultBatch.TERMINAL);
+          cleanup();
         } else {
           // if priority queue is empty, push the final accumulated value into 
the output batch and push it out
           outputBatch.add(currentCombinedValue);
           metricsAccumulator.incrementOutputRows(batchCounter + 1L);
           outputQueue.offer(outputBatch);
           // ... and the terminal value to indicate the blocking queue holding 
the values is complete
-          outputQueue.offer(ResultBatch.TERMINAL);
+          outputQueue.offer(ResultBatch.terminal());
           LOG.debug("merge combine complete after %s tasks", 
metricsAccumulator.getTaskCount());
         }
       }
       catch (Throwable t) {
-        closeAllCursors(pQueue);
         cancellationGizmo.cancel(t);
-        outputQueue.offer(ResultBatch.TERMINAL);
+        cleanup();
       }
     }
+
+    private void cleanup()
+    {
+      closeAllCursors(pQueue);
+      // offer terminal result if queue is not full in case out is empty to 
allow downstream threads waiting on
+      // stuff to be present to stop blocking immediately. However, if the 
queue is full, it doesn't matter if we
+      // write anything because the cancellation signal has been set, which 
will also terminate processing.
+      outputQueue.offer(ResultBatch.terminal());
+    }
   }
 
 
@@ -696,7 +728,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     private final List<BatchedResultsCursor<T>> partition;
     private final Ordering<T> orderingFn;
     private final BinaryOperator<T> combineFn;
-    private final QueuePusher<ResultBatch<T>> outputQueue;
+    private final QueuePusher<T> outputQueue;
     private final int yieldAfter;
     private final int batchSize;
     private final long targetTimeNanos;
@@ -707,7 +739,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
 
     private PrepareMergeCombineInputsAction(
         List<BatchedResultsCursor<T>> partition,
-        QueuePusher<ResultBatch<T>> outputQueue,
+        QueuePusher<T> outputQueue,
         Ordering<T> orderingFn,
         BinaryOperator<T> combineFn,
         int yieldAfter,
@@ -744,7 +776,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
             cursor.close();
           }
         }
-        if (cursors.size() > 0) {
+        if (!cancellationGizmo.isCanceled() && !cursors.isEmpty()) {
           getPool().execute(new MergeCombineAction<T>(
               cursors,
               outputQueue,
@@ -758,14 +790,17 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
               cancellationGizmo
           ));
         } else {
-          outputQueue.offer(ResultBatch.TERMINAL);
+          outputQueue.offer(ResultBatch.terminal());
         }
         metricsAccumulator.setPartitionInitializedTime(System.nanoTime() - 
startTime);
       }
       catch (Throwable t) {
         closeAllCursors(partition);
         cancellationGizmo.cancel(t);
-        outputQueue.offer(ResultBatch.TERMINAL);
+        // offer terminal result if queue is not full in case out is empty to 
allow downstream threads waiting on
+        // stuff to be present to stop blocking immediately. However, if the 
queue is full, it doesn't matter if we
+        // write anything because the cancellation signal has been set, which 
will also terminate processing.
+        outputQueue.tryOfferTerminal();
       }
     }
   }
@@ -779,12 +814,14 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
   {
     final boolean hasTimeout;
     final long timeoutAtNanos;
-    final BlockingQueue<E> queue;
-    volatile E item = null;
+    final BlockingQueue<ResultBatch<E>> queue;
+    final CancellationGizmo gizmo;
+    volatile ResultBatch<E> item = null;
 
-    QueuePusher(BlockingQueue<E> q, boolean hasTimeout, long timeoutAtNanos)
+    QueuePusher(BlockingQueue<ResultBatch<E>> q, CancellationGizmo gizmo, 
boolean hasTimeout, long timeoutAtNanos)
     {
       this.queue = q;
+      this.gizmo = gizmo;
       this.hasTimeout = hasTimeout;
       this.timeoutAtNanos = timeoutAtNanos;
     }
@@ -795,14 +832,16 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
       boolean success = false;
       if (item != null) {
         if (hasTimeout) {
-          final long thisTimeoutNanos = timeoutAtNanos - System.nanoTime();
-          if (thisTimeoutNanos < 0) {
+          final long remainingNanos = timeoutAtNanos - System.nanoTime();
+          if (remainingNanos < 0) {
             item = null;
-            throw new QueryTimeoutException("QueuePusher timed out offering 
data");
+            throw gizmo.cancelAndThrow(new QueryTimeoutException());
           }
-          success = queue.offer(item, thisTimeoutNanos, TimeUnit.NANOSECONDS);
+          final long blockTimeoutNanos = Math.min(remainingNanos, 
BLOCK_TIMEOUT);
+          success = queue.offer(item, blockTimeoutNanos, TimeUnit.NANOSECONDS);
         } else {
-          success = queue.offer(item);
+          queue.put(item);
+          success = true;
         }
         if (success) {
           item = null;
@@ -817,7 +856,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
       return item == null;
     }
 
-    public void offer(E item)
+    public void offer(ResultBatch<E> item)
     {
       try {
         this.item = item;
@@ -828,6 +867,11 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
         throw new RuntimeException("Failed to offer result to output queue", 
e);
       }
     }
+
+    public void tryOfferTerminal()
+    {
+      this.queue.offer(ResultBatch.terminal());
+    }
   }
 
   /**
@@ -837,8 +881,10 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
    */
   static class ResultBatch<E>
   {
-    @SuppressWarnings("rawtypes")
-    static final ResultBatch TERMINAL = new ResultBatch();
+    static <T> ResultBatch<T> terminal()
+    {
+      return new ResultBatch<>();
+    }
 
     @Nullable
     private final Queue<E> values;
@@ -855,19 +901,16 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
 
     public void add(E in)
     {
-      assert values != null;
       values.offer(in);
     }
 
     public E get()
     {
-      assert values != null;
       return values.peek();
     }
 
     public E next()
     {
-      assert values != null;
       return values.poll();
     }
 
@@ -925,6 +968,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     Yielder<ResultBatch<E>> getBatchYielder()
     {
       try {
+        batchYielder = null;
         ForkJoinPool.managedBlock(this);
         return batchYielder;
       }
@@ -1033,8 +1077,8 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     @Override
     public void initialize()
     {
-      yielder = batcher.getBatchYielder();
-      resultBatch = yielder.get();
+      yielder = null;
+      nextBatch();
     }
 
     @Override
@@ -1059,6 +1103,10 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     @Override
     public boolean block()
     {
+      if (yielder == null) {
+        yielder = batcher.getBatchYielder();
+        resultBatch = yielder.get();
+      }
       if (yielder.isDone()) {
         return true;
       }
@@ -1073,7 +1121,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     @Override
     public boolean isReleasable()
     {
-      return yielder.isDone() || (resultBatch != null && 
!resultBatch.isDrained());
+      return (yielder != null && yielder.isDone()) || (resultBatch != null && 
!resultBatch.isDrained());
     }
 
     @Override
@@ -1092,11 +1140,13 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
   static class BlockingQueueuBatchedResultsCursor<E> extends 
BatchedResultsCursor<E>
   {
     final BlockingQueue<ResultBatch<E>> queue;
+    final CancellationGizmo gizmo;
     final boolean hasTimeout;
     final long timeoutAtNanos;
 
     BlockingQueueuBatchedResultsCursor(
         BlockingQueue<ResultBatch<E>> blockingQueue,
+        CancellationGizmo cancellationGizmo,
         Ordering<E> ordering,
         boolean hasTimeout,
         long timeoutAtNanos
@@ -1104,6 +1154,7 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     {
       super(ordering);
       this.queue = blockingQueue;
+      this.gizmo = cancellationGizmo;
       this.hasTimeout = hasTimeout;
       this.timeoutAtNanos = timeoutAtNanos;
     }
@@ -1142,17 +1193,18 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
     {
       if (resultBatch == null || resultBatch.isDrained()) {
         if (hasTimeout) {
-          final long thisTimeoutNanos = timeoutAtNanos - System.nanoTime();
-          if (thisTimeoutNanos < 0) {
-            resultBatch = ResultBatch.TERMINAL;
-            throw new QueryTimeoutException("BlockingQueue cursor timed out 
waiting for data");
+          final long remainingNanos = timeoutAtNanos - System.nanoTime();
+          if (remainingNanos < 0) {
+            resultBatch = ResultBatch.terminal();
+            throw gizmo.cancelAndThrow(new QueryTimeoutException());
           }
-          resultBatch = queue.poll(thisTimeoutNanos, TimeUnit.NANOSECONDS);
+          final long blockTimeoutNanos = Math.min(remainingNanos, 
BLOCK_TIMEOUT);
+          resultBatch = queue.poll(blockTimeoutNanos, TimeUnit.NANOSECONDS);
         } else {
           resultBatch = queue.take();
         }
       }
-      return resultBatch != null;
+      return resultBatch != null && !resultBatch.isDrained();
     }
 
     @Override
@@ -1164,35 +1216,91 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
       }
       // if we can get a result immediately without blocking, also no need to 
block
       resultBatch = queue.poll();
-      return resultBatch != null;
+      return resultBatch != null && !resultBatch.isDrained();
     }
   }
 
   /**
-   * Token to allow any {@link RecursiveAction} signal the others and the 
output sequence that something bad happened
-   * and processing should cancel, such as a timeout or connection loss.
+   * Token used to stop internal parallel processing across all tasks in the 
merge pool. Allows any
+   * {@link RecursiveAction} signal the others and the output sequence that 
something bad happened and
+   * processing should cancel, such as a timeout, error, or connection loss.
    */
-  static class CancellationGizmo
+  public static class CancellationGizmo
   {
     private final AtomicReference<Throwable> throwable = new 
AtomicReference<>(null);
 
+    RuntimeException cancelAndThrow(Throwable t)
+    {
+      throwable.compareAndSet(null, t);
+      return wrapRuntimeException(t);
+    }
+
     void cancel(Throwable t)
     {
       throwable.compareAndSet(null, t);
     }
 
-    boolean isCancelled()
+    boolean isCanceled()
     {
       return throwable.get() != null;
     }
 
     RuntimeException getRuntimeException()
     {
-      Throwable ex = throwable.get();
-      if (ex instanceof RuntimeException) {
-        return (RuntimeException) ex;
+      return wrapRuntimeException(throwable.get());
+    }
+
+    private static RuntimeException wrapRuntimeException(Throwable t)
+    {
+      if (t instanceof RuntimeException) {
+        return (RuntimeException) t;
       }
-      return new RE(ex);
+      return new RuntimeException(t);
+    }
+  }
+
+  /**
+   * {@link com.google.common.util.concurrent.ListenableFuture} that allows 
{@link ParallelMergeCombiningSequence} to be
+   * registered with {@link 
org.apache.druid.query.QueryWatcher#registerQueryFuture} to participate in query
+   * cancellation or anything else that has a need to watch the activity on 
the merge pool. Wraps a
+   * {@link CancellationGizmo} to allow for external threads to signal 
cancellation of parallel processing on the pool
+   * by triggering {@link CancellationGizmo#cancel(Throwable)} whenever {@link 
#cancel(boolean)} is called.
+   *
+   * This is not used internally by workers on the pool in favor of using the 
much simpler {@link CancellationGizmo}
+   * directly instead.
+   */
+  public static class CancellationFuture extends AbstractFuture<Boolean>
+  {
+    private final CancellationGizmo cancellationGizmo;
+
+    public CancellationFuture(CancellationGizmo cancellationGizmo)
+    {
+      this.cancellationGizmo = cancellationGizmo;
+    }
+
+    public CancellationGizmo getCancellationGizmo()
+    {
+      return cancellationGizmo;
+    }
+
+    @Override
+    public boolean set(Boolean value)
+    {
+      return super.set(value);
+    }
+
+    @Override
+    public boolean setException(Throwable throwable)
+    {
+      cancellationGizmo.cancel(throwable);
+      return super.setException(throwable);
+    }
+
+    @Override
+    public boolean cancel(boolean mayInterruptIfRunning)
+    {
+      cancellationGizmo.cancel(new RuntimeException("Sequence canceled"));
+      return super.cancel(mayInterruptIfRunning);
     }
   }
 
@@ -1308,8 +1416,8 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
    */
   static class MergeCombineMetricsAccumulator
   {
-    List<MergeCombineActionMetricsAccumulator> partitionMetrics;
-    MergeCombineActionMetricsAccumulator mergeMetrics;
+    List<MergeCombineActionMetricsAccumulator> partitionMetrics = 
Collections.emptyList();
+    MergeCombineActionMetricsAccumulator mergeMetrics = new 
MergeCombineActionMetricsAccumulator();
 
     private long totalWallTime;
 
@@ -1343,8 +1451,8 @@ public class ParallelMergeCombiningSequence<T> extends 
YieldingSequenceBase<T>
       // partition
       long totalPoolTasks = 1 + 1 + partitionMetrics.size();
 
-      long fastestPartInitialized = partitionMetrics.size() > 0 ? 
Long.MAX_VALUE : mergeMetrics.getPartitionInitializedtime();
-      long slowestPartInitialied = partitionMetrics.size() > 0 ? 
Long.MIN_VALUE : mergeMetrics.getPartitionInitializedtime();
+      long fastestPartInitialized = !partitionMetrics.isEmpty() ? 
Long.MAX_VALUE : mergeMetrics.getPartitionInitializedtime();
+      long slowestPartInitialied = !partitionMetrics.isEmpty() ? 
Long.MIN_VALUE : mergeMetrics.getPartitionInitializedtime();
 
       // accumulate input row count, cpu time, and total number of tasks from 
each partition
       for (MergeCombineActionMetricsAccumulator partition : partitionMetrics) {
diff --git a/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java b/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java
index ca34c364dca..5b76afb9022 100644
--- a/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java
+++ b/processing/src/test/java/org/apache/druid/java/util/common/guava/ParallelMergeCombiningSequenceTest.java
@@ -143,7 +143,7 @@ public class ParallelMergeCombiningSequenceTest
     if (!currentBatch.isDrained()) {
       outputQueue.offer(currentBatch);
     }
-    outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.TERMINAL);
+    outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.terminal());
 
     rawYielder.close();
     cursor.close();
@@ -211,16 +211,18 @@ public class ParallelMergeCombiningSequenceTest
     if (!currentBatch.isDrained()) {
       outputQueue.offer(currentBatch);
     }
-    outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.TERMINAL);
+    outputQueue.offer(ParallelMergeCombiningSequence.ResultBatch.terminal());
 
     rawYielder.close();
     cursor.close();
 
     rawYielder = Yielders.each(rawSequence);
 
+    ParallelMergeCombiningSequence.CancellationGizmo gizmo = new 
ParallelMergeCombiningSequence.CancellationGizmo();
     ParallelMergeCombiningSequence.BlockingQueueuBatchedResultsCursor<IntPair> 
queueCursor =
         new 
ParallelMergeCombiningSequence.BlockingQueueuBatchedResultsCursor<>(
             outputQueue,
+            gizmo,
             INT_PAIR_ORDERING,
             false,
             -1L
@@ -551,14 +553,14 @@ public class ParallelMergeCombiningSequenceTest
   }
 
   @Test
-  public void testTimeoutExceptionDueToStalledReader()
+  public void testTimeoutExceptionDueToSlowReader()
   {
-    final int someSize = 2048;
+    final int someSize = 50_000;
     List<Sequence<IntPair>> input = new ArrayList<>();
-    input.add(nonBlockingSequence(someSize));
-    input.add(nonBlockingSequence(someSize));
-    input.add(nonBlockingSequence(someSize));
-    input.add(nonBlockingSequence(someSize));
+    input.add(nonBlockingSequence(someSize, true));
+    input.add(nonBlockingSequence(someSize, true));
+    input.add(nonBlockingSequence(someSize, true));
+    input.add(nonBlockingSequence(someSize, true));
 
     Throwable t = Assert.assertThrows(QueryTimeoutException.class, () -> 
assertException(input, 8, 64, 1000, 1500));
     Assert.assertEquals("Query did not complete within configured timeout 
period. " +
@@ -567,6 +569,110 @@ public class ParallelMergeCombiningSequenceTest
     Assert.assertTrue(pool.isQuiescent());
   }
 
+  @Test
+  public void testTimeoutExceptionDueToStoppedReader() throws 
InterruptedException
+  {
+    final int someSize = 150_000;
+    List<TestingReporter> reporters = new ArrayList<>();
+    for (int i = 0; i < 100; i++) {
+      List<Sequence<IntPair>> input = new ArrayList<>();
+      input.add(nonBlockingSequence(someSize, true));
+      input.add(nonBlockingSequence(someSize, true));
+      input.add(nonBlockingSequence(someSize, true));
+      input.add(nonBlockingSequence(someSize, true));
+
+      TestingReporter reporter = new TestingReporter();
+      final ParallelMergeCombiningSequence<IntPair> 
parallelMergeCombineSequence = new ParallelMergeCombiningSequence<>(
+          pool,
+          input,
+          INT_PAIR_ORDERING,
+          INT_PAIR_MERGE_FN,
+          true,
+          1000,
+          0,
+          TEST_POOL_SIZE,
+          512,
+          128,
+          ParallelMergeCombiningSequence.DEFAULT_TASK_TARGET_RUN_TIME_MILLIS,
+          reporter
+      );
+      Yielder<IntPair> parallelMergeCombineYielder = 
Yielders.each(parallelMergeCombineSequence);
+      reporter.future = parallelMergeCombineSequence.getCancellationFuture();
+      reporter.yielder = parallelMergeCombineYielder;
+      reporter.yielder = parallelMergeCombineYielder.next(null);
+      Assert.assertFalse(parallelMergeCombineYielder.isDone());
+      reporters.add(reporter);
+    }
+
+    // sleep until timeout
+    Thread.sleep(1000);
+    Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS));
+    Assert.assertTrue(pool.isQuiescent());
+    Assert.assertFalse(pool.hasQueuedSubmissions());
+    for (TestingReporter reporter : reporters) {
+      Assert.assertThrows(QueryTimeoutException.class, () -> 
reporter.yielder.next(null));
+      Assert.assertTrue(reporter.future.isCancelled());
+      Assert.assertTrue(reporter.future.getCancellationGizmo().isCanceled());
+    }
+    Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS));
+    Assert.assertTrue(pool.isQuiescent());
+  }
+
+  @Test
+  public void testManyBigSequencesAllAtOnce() throws IOException
+  {
+    final int someSize = 50_000;
+    List<TestingReporter> reporters = new ArrayList<>();
+    for (int i = 0; i < 100; i++) {
+      List<Sequence<IntPair>> input = new ArrayList<>();
+      input.add(nonBlockingSequence(someSize, true));
+      input.add(nonBlockingSequence(someSize, true));
+      input.add(nonBlockingSequence(someSize, true));
+      input.add(nonBlockingSequence(someSize, true));
+
+      TestingReporter reporter = new TestingReporter();
+      final ParallelMergeCombiningSequence<IntPair> 
parallelMergeCombineSequence = new ParallelMergeCombiningSequence<>(
+          pool,
+          input,
+          INT_PAIR_ORDERING,
+          INT_PAIR_MERGE_FN,
+          true,
+          30 * 1000,
+          0,
+          TEST_POOL_SIZE,
+          512,
+          128,
+          ParallelMergeCombiningSequence.DEFAULT_TASK_TARGET_RUN_TIME_MILLIS,
+          reporter
+      );
+      Yielder<IntPair> parallelMergeCombineYielder = 
Yielders.each(parallelMergeCombineSequence);
+      reporter.future = parallelMergeCombineSequence.getCancellationFuture();
+      reporter.yielder = parallelMergeCombineYielder;
+      parallelMergeCombineYielder.next(null);
+      Assert.assertFalse(parallelMergeCombineYielder.isDone());
+      reporters.add(reporter);
+    }
+
+    for (TestingReporter testingReporter : reporters) {
+      Yielder<IntPair> parallelMergeCombineYielder = testingReporter.yielder;
+      while (!parallelMergeCombineYielder.isDone()) {
+        parallelMergeCombineYielder = 
parallelMergeCombineYielder.next(parallelMergeCombineYielder.get());
+      }
+      Assert.assertTrue(parallelMergeCombineYielder.isDone());
+      parallelMergeCombineYielder.close();
+      Assert.assertTrue(testingReporter.future.isDone());
+    }
+
+    Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS));
+    Assert.assertTrue(pool.isQuiescent());
+    Assert.assertEquals(0, pool.getRunningThreadCount());
+    Assert.assertFalse(pool.hasQueuedSubmissions());
+    Assert.assertEquals(0, pool.getActiveThreadCount());
+    for (TestingReporter reporter : reporters) {
+      Assert.assertTrue(reporter.done);
+    }
+  }
+
   @Test
   public void testGracefulCloseOfYielderCancelsPool() throws IOException
   {
@@ -666,7 +772,9 @@ public class ParallelMergeCombiningSequenceTest
     parallelMergeCombineYielder.close();
     // cancellation trigger should not be set if sequence was fully yielded 
and close is called
     // (though shouldn't actually matter even if it was...)
-    
Assert.assertFalse(parallelMergeCombineSequence.getCancellationGizmo().isCancelled());
+    
Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().isCancelled());
+    
Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().isDone());
+    
Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().isCanceled());
   }
 
   private void assertResult(
@@ -713,13 +821,15 @@ public class ParallelMergeCombiningSequenceTest
 
     Assert.assertTrue(combiningYielder.isDone());
     Assert.assertTrue(parallelMergeCombineYielder.isDone());
-    Assert.assertTrue(pool.awaitQuiescence(1, TimeUnit.SECONDS));
+    Assert.assertTrue(pool.awaitQuiescence(5, TimeUnit.SECONDS));
     Assert.assertTrue(pool.isQuiescent());
     combiningYielder.close();
     parallelMergeCombineYielder.close();
     // cancellation trigger should not be set if sequence was fully yielded 
and close is called
     // (though shouldn't actually matter even if it was...)
-    
Assert.assertFalse(parallelMergeCombineSequence.getCancellationGizmo().isCancelled());
+    
Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().isCancelled());
+    
Assert.assertFalse(parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().isCanceled());
+    
Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().isDone());
   }
 
   private void assertResultWithEarlyClose(
@@ -773,20 +883,21 @@ public class ParallelMergeCombiningSequenceTest
       }
     }
     // trying to next the yielder creates sadness for you
-    final String expectedExceptionMsg = "Already closed";
+    final String expectedExceptionMsg = "Sequence canceled";
     Assert.assertEquals(combiningYielder.get(), 
parallelMergeCombineYielder.get());
     final Yielder<IntPair> finalYielder = parallelMergeCombineYielder;
     Throwable t = Assert.assertThrows(RuntimeException.class, () -> 
finalYielder.next(finalYielder.get()));
     Assert.assertEquals(expectedExceptionMsg, t.getMessage());
 
     // cancellation gizmo of sequence should be cancelled, and also should 
contain our expected message
-    
Assert.assertTrue(parallelMergeCombineSequence.getCancellationGizmo().isCancelled());
+    
Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().isCanceled());
     Assert.assertEquals(
         expectedExceptionMsg,
-        
parallelMergeCombineSequence.getCancellationGizmo().getRuntimeException().getMessage()
+        
parallelMergeCombineSequence.getCancellationFuture().getCancellationGizmo().getRuntimeException().getMessage()
     );
+    
Assert.assertTrue(parallelMergeCombineSequence.getCancellationFuture().isCancelled());
 
-    Assert.assertTrue(pool.awaitQuiescence(1, TimeUnit.SECONDS));
+    Assert.assertTrue(pool.awaitQuiescence(10, TimeUnit.SECONDS));
     Assert.assertTrue(pool.isQuiescent());
 
     Assert.assertFalse(combiningYielder.isDone());
@@ -1082,4 +1193,19 @@ public class ParallelMergeCombiningSequenceTest
   {
     return new IntPair(mergeKey, ThreadLocalRandom.current().nextInt(1, 100));
   }
+
+  static class TestingReporter implements 
Consumer<ParallelMergeCombiningSequence.MergeCombineMetrics>
+  {
+    ParallelMergeCombiningSequence.CancellationFuture future;
+    Yielder<IntPair> yielder;
+    volatile ParallelMergeCombiningSequence.MergeCombineMetrics metrics;
+    volatile boolean done = false;
+
+    @Override
+    public void accept(ParallelMergeCombiningSequence.MergeCombineMetrics 
mergeCombineMetrics)
+    {
+      metrics = mergeCombineMetrics;
+      done = true;
+    }
+  }
 }
diff --git 
a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java 
b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
index 5fa34d6699d..e4027bcd357 100644
--- a/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
+++ b/server/src/main/java/org/apache/druid/client/CachingClusteredClient.java
@@ -384,7 +384,7 @@ public class CachingClusteredClient implements 
QuerySegmentWalker
       BinaryOperator<T> mergeFn = toolChest.createMergeFn(query);
       final QueryContext queryContext = query.context();
       if (parallelMergeConfig.useParallelMergePool() && 
queryContext.getEnableParallelMerges() && mergeFn != null) {
-        return new ParallelMergeCombiningSequence<>(
+        final ParallelMergeCombiningSequence<T> parallelSequence = new 
ParallelMergeCombiningSequence<>(
             pool,
             sequencesByInterval,
             query.getResultOrdering(),
@@ -414,6 +414,8 @@ public class CachingClusteredClient implements 
QuerySegmentWalker
               }
             }
         );
+        scheduler.registerQueryFuture(query, 
parallelSequence.getCancellationFuture());
+        return parallelSequence;
       } else {
         return Sequences
             .simple(sequencesByInterval)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to