This is an automated email from the ASF dual-hosted git repository.

karan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b605aa9cf Multiple fixes for the MSQ stats merging piece which (#13463)
2b605aa9cf is described below

commit 2b605aa9cf987ccb8e9ace04f71214aa6b508ec4
Author: Adarsh Sanjeev <[email protected]>
AuthorDate: Thu Dec 15 09:35:11 2022 +0530

    Multiple fixes for the MSQ stats merging piece which (#13463)
    
    * Add validation checks to worker chat handler apis
    
    * Merge things and polishing the error messages.
    
    * Minor error message change
    
    * Fixing race and adding some tests
    
    * Fixing controller fetching stats from wrong workers.
    Fixing race
    Changing default mode to Parallel
    Adding logging.
    Fixing exceptions not propagated properly.
    
    * Changing to kernel worker count
    
    * Added a better logic to figure out assigned worker for a stage.
    
    * Nits
    
    * Moving to existing kernel methods
    
    * Adding more coverage
    
    Co-authored-by: cryptoe <[email protected]>
---
 docs/multi-stage-query/reference.md                |   2 +-
 .../org/apache/druid/msq/exec/ControllerImpl.java  |  12 +-
 .../msq/exec/ExceptionWrappingWorkerClient.java    |  14 +-
 .../java/org/apache/druid/msq/exec/WorkerImpl.java |  29 ++-
 .../apache/druid/msq/exec/WorkerSketchFetcher.java |  54 ++++-
 .../druid/msq/indexing/WorkerChatHandler.java      |  49 +++-
 .../druid/msq/kernel/worker/WorkerStageKernel.java |   3 +-
 .../statistics/ClusterByStatisticsSnapshot.java    |   4 +
 .../druid/msq/util/MultiStageQueryContext.java     |   2 +-
 .../org/apache/druid/msq/exec/MSQInsertTest.java   |  26 +++
 .../org/apache/druid/msq/exec/WorkerImplTest.java  |  54 +++++
 .../msq/exec/WorkerSketchFetcherAutoModeTest.java  |  39 +++-
 .../druid/msq/exec/WorkerSketchFetcherTest.java    |  37 ++-
 .../druid/msq/indexing/WorkerChatHandlerTest.java  | 254 +++++++++++++++++++++
 14 files changed, 529 insertions(+), 50 deletions(-)
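
Reviewer note: the "controller fetching stats from wrong workers" fix in WorkerSketchFetcher comes down to iterating over the worker numbers assigned to the stage instead of over 0..n-1. The sketch below is a minimal standalone illustration of that difference, not the Druid code: the class name WorkerSelectionSketch, both method names, and the task ids are hypothetical, and it uses java.util collections in place of the fastutil IntSet used by the commit.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical illustration only; none of these names exist in Druid.
public class WorkerSelectionSketch
{
  // Old approach: assume the stage ran on workers 0..n-1, where n is the total task count.
  static List<Integer> workersByTaskCount(List<String> workerTaskIds)
  {
    return IntStream.range(0, workerTaskIds.size()).boxed().collect(Collectors.toList());
  }

  // New approach: contact only the workers that were assigned to the stage.
  // The TreeSet is used here just to get a deterministic (sorted) order.
  static List<Integer> workersByAssignment(Set<Integer> workersForStage)
  {
    return new ArrayList<>(new TreeSet<>(workersForStage));
  }

  public static void main(String[] args)
  {
    List<String> workerTaskIds = List.of("task-0", "task-1", "task-2", "task-3");
    Set<Integer> workersForStage = Set.of(1, 3);

    // Range-based selection would also query workers 0 and 2, which hold no stats for this stage.
    System.out.println(workersByTaskCount(workerTaskIds));
    // Assignment-based selection queries only workers 1 and 3.
    System.out.println(workersByAssignment(workersForStage));
  }
}
```

With four tasks and a stage assigned to workers 1 and 3, the first call selects all four workers while the second selects only the two assigned ones, which is why the completion count in the parallel merge has to be based on the assigned set as well.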

diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md
index 8ea9adf61a..5016b6ab48 100644
--- a/docs/multi-stage-query/reference.md
+++ b/docs/multi-stage-query/reference.md
@@ -325,7 +325,7 @@ The following table lists the context parameters for the MSQ task engine:
 | `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1.| 0 |
 | `rowsPerSegment` | INSERT or REPLACE<br /><br />The number of rows per segment to target. The actual number of rows per segment may be somewhat higher or lower than this number. In most cases, use the default. For general information about sizing rows per segment, see [Segment Size Optimization](../operations/segment-optimization.md). | 3,000,000 |
 | `indexSpec` | INSERT or REPLACE<br /><br />An [`indexSpec`](../ingestion/ingestion-spec.md#indexspec) to use when generating segments. May be a JSON string or object. See [Front coding](../ingestion/ingestion-spec.md#front-coding) for details on configuring an `indexSpec` with front coding. | See [`indexSpec`](../ingestion/ingestion-spec.md#indexspec). |
-| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `AUTO` |
+| `clusterStatisticsMergeMode` | Whether to use parallel or sequential mode for merging of the worker sketches. Can be `PARALLEL`, `SEQUENTIAL` or `AUTO`. See [Sketch Merging Mode](#sketch-merging-mode) for more information. | `PARALLEL` |
 
 ## Sketch Merging Mode
 This section details the advantages and performance of various Cluster By Statistics Merge Modes.
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
index 528baa4c27..f36b91e55e 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
@@ -263,6 +263,7 @@ public class ControllerImpl implements Controller
   // For live reports. Written by the main controller thread, read by HTTP threads.
   private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>();
 
+
   private WorkerSketchFetcher workerSketchFetcher;
   // Time at which the query started.
   // For live reports. Written by the main controller thread, read by HTTP threads.
@@ -624,14 +625,21 @@ public class ControllerImpl implements Controller
                 workerSketchFetcher.submitFetcherTask(
                     completeKeyStatisticsInformation,
                     workerTaskIds,
-                    stageDef
+                    stageDef,
+                    queryKernel.getWorkerInputsForStage(stageId).workers()
+                    // we only need tasks which are active for this stage.
                 );
 
             // Add the listener to handle completion.
             clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, throwable) -> {
               addToKernelManipulationQueue(holder -> {
                 if (throwable != null) {
-                  holder.failStageForReason(stageId, UnknownFault.forException(throwable));
+                  log.error("Error while fetching stats for stageId[%s]", stageId);
+                  if (throwable instanceof MSQException) {
+                    holder.failStageForReason(stageId, ((MSQException) throwable).getFault());
+                  } else {
+                    holder.failStageForReason(stageId, UnknownFault.forException(throwable));
+                  }
                 } else if (clusterByPartitionsEither.isError()) {
                   holder.failStageForReason(stageId, new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
                 } else {
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java
index 3d78b7c9ce..eb6b1af529 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java
@@ -57,9 +57,13 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
   }
 
   @Override
-  public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(String workerTaskId, String queryId, int stageNumber)
+  public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
+      String workerTaskId,
+      String queryId,
+      int stageNumber
+  )
   {
-    return client.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber);
+    return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, queryId, stageNumber));
   }
 
   @Override
@@ -70,7 +74,11 @@ public class ExceptionWrappingWorkerClient implements WorkerClient
       long timeChunk
   )
   {
-    return client.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk);
+    return wrap(
+        workerTaskId,
+        client,
+        c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, queryId, stageNumber, timeChunk)
+    );
   }
 
   @Override
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
index 49d6f9080d..8c5a782f53 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java
@@ -571,16 +571,37 @@ public class WorkerImpl implements Worker
   @Override
   public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
   {
-    return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
+    if (stageKernelMap.get(stageId) == null) {
+      throw new ISE("Requested statistics snapshot for non-existent stageId %s.", stageId);
+    } else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) {
+      throw new ISE(
+          "Requested statistics snapshot is not generated yet for stageId[%s]",
+          stageId
+      );
+    } else {
+      return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
+    }
   }
 
   @Override
   public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
   {
-    ClusterByStatisticsSnapshot snapshot = stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
-    return snapshot.getSnapshotForTimeChunk(timeChunk);
+    if (stageKernelMap.get(stageId) == null) {
+      throw new ISE("Requested statistics snapshot for non-existent stageId[%s].", stageId);
+    } else if (stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot() == null) {
+      throw new ISE(
+          "Requested statistics snapshot is not generated yet for stageId[%s]",
+          stageId
+      );
+    } else {
+      return stageKernelMap.get(stageId)
+                           .getResultKeyStatisticsSnapshot()
+                           .getSnapshotForTimeChunk(timeChunk);
+    }
+
   }
 
+
   @Override
   public CounterSnapshotsTree getCounters()
   {
@@ -643,7 +664,7 @@ public class WorkerImpl implements Worker
   /**
    * Decorates the server-wide {@link QueryProcessingPool} such that any Callables and Runnables, not just
    * {@link PrioritizedCallable} and {@link PrioritizedRunnable}, may be added to it.
-   *
+   * <p>
    * In production, the underlying {@link QueryProcessingPool} pool is set up by
    * {@link org.apache.druid.guice.DruidProcessingModule}.
    */
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
index dc6f219905..2eba0c409d 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java
@@ -20,6 +20,7 @@
 package org.apache.druid.msq.exec;
 
 import com.google.common.util.concurrent.ListenableFuture;
+import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
@@ -40,7 +41,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
-import java.util.stream.IntStream;
+import java.util.stream.Collectors;
 
 /**
  * Queues up fetching sketches from workers and progressively generates partitions boundaries.
@@ -78,7 +79,8 @@ public class WorkerSketchFetcher implements AutoCloseable
   public CompletableFuture<Either<Long, ClusterByPartitions>> submitFetcherTask(
       CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
       List<String> workerTaskIds,
-      StageDefinition stageDefinition
+      StageDefinition stageDefinition,
+      IntSet workersForStage
   )
   {
     ClusterBy clusterBy = stageDefinition.getClusterBy();
@@ -87,18 +89,31 @@ public class WorkerSketchFetcher implements AutoCloseable
       case SEQUENTIAL:
         return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
       case PARALLEL:
-        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
       case AUTO:
         if (clusterBy.getBucketByCount() == 0) {
-          log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
+          log.info(
+              "Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
+              stageDefinition.getId().getQueryId(),
+              stageDefinition.getStageNumber()
+          );
           // If there is no time clustering, there is no scope for sequential merge
-          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
-        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
-          log.info("Query [%s] AUTO mode: chose SEQUENTIAL mode to merge key statistics", stageDefinition.getId().getQueryId());
+          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
+        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD
+                   || completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+          log.info(
+              "Query[%s] stage[%d] for AUTO mode: chose SEQUENTIAL mode to merge key statistics",
+              stageDefinition.getId().getQueryId(),
+              stageDefinition.getStageNumber()
+          );
           return sequentialTimeChunkMerging(completeKeyStatisticsInformation, stageDefinition, workerTaskIds);
         }
-        log.info("Query [%s] AUTO mode: chose PARALLEL mode to merge key statistics", stageDefinition.getId().getQueryId());
-        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        log.info(
+            "Query[%s] stage[%d] for AUTO mode: chose PARALLEL mode to merge key statistics",
+            stageDefinition.getId().getQueryId(),
+            stageDefinition.getStageNumber()
+        );
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds, workersForStage);
       default:
         throw new IllegalStateException("No fetching strategy found for mode: " + clusterStatisticsMergeMode);
     }
@@ -111,7 +126,8 @@ public class WorkerSketchFetcher implements AutoCloseable
    */
   CompletableFuture<Either<Long, ClusterByPartitions>> inMemoryFullSketchMerging(
       StageDefinition stageDefinition,
-      List<String> workerTaskIds
+      List<String> workerTaskIds,
+      IntSet workersForStage
   )
   {
     CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new CompletableFuture<>();
@@ -119,12 +135,19 @@ public class WorkerSketchFetcher implements AutoCloseable
     // Create a new key statistics collector to merge worker sketches into
     final ClusterByStatisticsCollector mergedStatisticsCollector =
         stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
-    final int workerCount = workerTaskIds.size();
+    final int workerCount = workersForStage.size();
     // Guarded by synchronized mergedStatisticsCollector
     final Set<Integer> finishedWorkers = new HashSet<>();
 
+    log.info(
+        "Fetching stats using %s for stage[%d] for workers[%s] ",
+        ClusterStatisticsMergeMode.PARALLEL,
+        stageDefinition.getStageNumber(),
+        workersForStage.stream().map(Object::toString).collect(Collectors.joining(","))
+    );
+
     // Submit a task for each worker to fetch statistics
-    IntStream.range(0, workerCount).forEach(workerNo -> {
+    workersForStage.forEach(workerNo -> {
       executorService.submit(() -> {
         ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
             workerClient.fetchClusterByStatisticsSnapshot(
@@ -177,6 +200,13 @@ public class WorkerSketchFetcher implements AutoCloseable
         workerTaskIds,
         completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
     );
+
+    log.info(
+        "Fetching stats using %s for stage[%d] for tasks[%s]",
+        ClusterStatisticsMergeMode.SEQUENTIAL,
+        stageDefinition.getStageNumber(),
+        String.join(",", workerTaskIds)
+    );
     sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
     return sequentialFetchStage.getPartitionFuture();
   }
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java
index dd6ea7cb71..3eae3b05cc 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/WorkerChatHandler.java
@@ -19,11 +19,13 @@
 
 package org.apache.druid.msq.indexing;
 
+import com.google.common.collect.ImmutableMap;
 import it.unimi.dsi.fastutil.bytes.ByteArrays;
 import org.apache.commons.lang.mutable.MutableLong;
 import org.apache.druid.frame.file.FrameFileHttpResponseHandler;
 import org.apache.druid.frame.key.ClusterByPartitions;
 import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.msq.exec.Worker;
 import org.apache.druid.msq.kernel.StageId;
@@ -71,7 +73,7 @@ public class WorkerChatHandler implements ChatHandler
 
   /**
    * Returns up to {@link #CHANNEL_DATA_CHUNK_SIZE} bytes of stage output data.
-   *
+   * <p>
    * See {@link org.apache.druid.msq.exec.WorkerClient#fetchChannelData} for the client-side code that calls this API.
    */
   @GET
@@ -193,17 +195,30 @@ public class WorkerChatHandler implements ChatHandler
     ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
     ClusterByStatisticsSnapshot clusterByStatisticsSnapshot;
     StageId stageId = new StageId(queryId, stageNumber);
-    clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
-    return Response.status(Response.Status.ACCEPTED)
-                   .entity(clusterByStatisticsSnapshot)
-                   .build();
+    try {
+      clusterByStatisticsSnapshot = worker.fetchStatisticsSnapshot(stageId);
+      return Response.status(Response.Status.ACCEPTED)
+                     .entity(clusterByStatisticsSnapshot)
+                     .build();
+    }
+    catch (Exception e) {
+      String errorMessage = StringUtils.format(
+          "Invalid request for key statistics for query[%s] and stage[%d]",
+          queryId,
+          stageNumber
+      );
+      log.error(e, errorMessage);
+      return Response.status(Response.Status.BAD_REQUEST)
+                     .entity(ImmutableMap.<String, Object>of("error", errorMessage))
+                     .build();
+    }
   }
 
   @POST
   @Path("/keyStatisticsForTimeChunk/{queryId}/{stageNumber}/{timeChunk}")
   @Produces(MediaType.APPLICATION_JSON)
   @Consumes(MediaType.APPLICATION_JSON)
-  public Response httpSketch(
+  public Response httpFetchKeyStatisticsWithSnapshot(
       @PathParam("queryId") final String queryId,
       @PathParam("stageNumber") final int stageNumber,
       @PathParam("timeChunk") final long timeChunk,
@@ -213,10 +228,24 @@ public class WorkerChatHandler implements ChatHandler
     ChatHandlers.authorizationCheck(req, Action.READ, task.getDataSource(), toolbox.getAuthorizerMapper());
     ClusterByStatisticsSnapshot snapshotForTimeChunk;
     StageId stageId = new StageId(queryId, stageNumber);
-    snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
-    return Response.status(Response.Status.ACCEPTED)
-                   .entity(snapshotForTimeChunk)
-                   .build();
+    try {
+      snapshotForTimeChunk = worker.fetchStatisticsSnapshotForTimeChunk(stageId, timeChunk);
+      return Response.status(Response.Status.ACCEPTED)
+                     .entity(snapshotForTimeChunk)
+                     .build();
+    }
+    catch (Exception e) {
+      String errorMessage = StringUtils.format(
+          "Invalid request for key statistics for query[%s], stage[%d] and timeChunk[%d]",
+          queryId,
+          stageNumber,
+          timeChunk
+      );
+      log.error(e, errorMessage);
+      return Response.status(Response.Status.BAD_REQUEST)
+                     .entity(ImmutableMap.<String, Object>of("error", errorMessage))
+                     .build();
+    }
   }
 
   /**
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
index b0ed8e5c19..00a49656be 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/worker/WorkerStageKernel.java
@@ -48,8 +48,9 @@ public class WorkerStageKernel
 
   private WorkerStagePhase phase = WorkerStagePhase.NEW;
 
+  // We read this variable in the main thread and the netty threads
   @Nullable
-  private ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
+  private volatile ClusterByStatisticsSnapshot resultKeyStatisticsSnapshot;
 
   @Nullable
   private ClusterByPartitions resultPartitionBoundaries;
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java
index e54253ad21..16a4c1656b 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/ClusterByStatisticsSnapshot.java
@@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableMap;
 import org.apache.druid.frame.key.RowKey;
+import org.apache.druid.java.util.common.ISE;
 
 import javax.annotation.Nullable;
 import java.util.Collections;
@@ -61,6 +62,9 @@ public class ClusterByStatisticsSnapshot
   public ClusterByStatisticsSnapshot getSnapshotForTimeChunk(long timeChunk)
   {
     Bucket bucket = buckets.get(timeChunk);
+    if (bucket == null) {
+      throw new ISE("ClusterByStatistics not present for requested timechunk %s", timeChunk);
+    }
     return new ClusterByStatisticsSnapshot(ImmutableMap.of(timeChunk, bucket), null);
   }
 
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
index 7c589f2326..3dc622870f 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
@@ -60,7 +60,7 @@ public class MultiStageQueryContext
 
   public static final String CTX_ENABLE_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
   public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
-  public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.AUTO.toString();
+  public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.PARALLEL.toString();
   private static final boolean DEFAULT_ENABLE_DURABLE_SHUFFLE_STORAGE = false;
 
   public static final String CTX_DESTINATION = "destination";
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
index f54d2fa880..cf4e4052d3 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
@@ -128,6 +128,32 @@ public class MSQInsertTest extends MSQTestBase
 
   }
 
+  @Test
+  public void testInsertOnFoo1WithTimeFunctionWithSequential()
+  {
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("cnt", ColumnType.LONG).build();
+    Map<String, Object> context = ImmutableMap.<String, Object>builder()
+                                              .putAll(DEFAULT_MSQ_CONTEXT)
+                                              .put(
+                                                  MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
+                                                  ClusterStatisticsMergeMode.SEQUENTIAL.toString()
+                                              )
+                                              .build();
+
+    testIngestQuery().setSql(
+                         "insert into foo1 select  floor(__time to day) as __time , dim1 , count(*) as cnt from foo where dim1 is not null group by 1, 2 PARTITIONED by day clustered by dim1")
+                     .setQueryContext(context)
+                     .setExpectedDataSource("foo1")
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedSegment(expectedFooSegments())
+                     .setExpectedResultRows(expectedFooRows())
+                     .verifyResults();
+
+  }
+
   @Test
   public void testInsertOnFoo1WithMultiValueDim()
   {
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java
new file mode 100644
index 0000000000..52231a116b
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerImplTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.indexing.MSQWorkerTask;
+import org.apache.druid.msq.kernel.StageId;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import java.util.HashMap;
+
+
+@RunWith(MockitoJUnitRunner.class)
+public class WorkerImplTest
+{
+  @Mock
+  WorkerContext workerContext;
+
+  @Test
+  public void testFetchStatsThrows()
+  {
+    WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
+    Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshot(new StageId("xx", 1)));
+  }
+
+  @Test
+  public void testFetchStatsWithTimeChunkThrows()
+  {
+    WorkerImpl worker = new WorkerImpl(new MSQWorkerTask("controller", "ds", 1, new HashMap<>()), workerContext);
+    Assert.assertThrows(ISE.class, () -> worker.fetchStatisticsSnapshotForTimeChunk(new StageId("xx", 1), 1L));
+  }
+
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java
index 42f6f0437f..02be2876f9 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherAutoModeTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.druid.msq.exec;
 
+import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.kernel.StageId;
@@ -56,7 +57,7 @@ public class WorkerSketchFetcherAutoModeTest
 
     target = spy(new WorkerSketchFetcher(mock(WorkerClient.class), ClusterStatisticsMergeMode.AUTO, 300_000_000));
     // Don't actually try to fetch sketches
-    doReturn(null).when(target).inMemoryFullSketchMerging(any(), any());
+    doReturn(null).when(target).inMemoryFullSketchMerging(any(), any(), any());
     doReturn(null).when(target).sequentialTimeChunkMerging(any(), any(), any());
 
     doReturn(StageId.fromString("1_1")).when(stageDefinition).getId();
@@ -81,8 +82,13 @@ public class WorkerSketchFetcherAutoModeTest
     // Worker count below threshold
     doReturn(1).when(stageDefinition).getMaxWorkerCount();
 
-    target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
-    verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
+    target.submitFetcherTask(
+        completeKeyStatisticsInformation,
+        Collections.emptyList(),
+        stageDefinition,
+        IntSet.of()
+    );
+    verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
     verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
   }
 
@@ -98,8 +104,13 @@ public class WorkerSketchFetcherAutoModeTest
     // Worker count below threshold
     doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
 
-    target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
-    verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
+    target.submitFetcherTask(
+        completeKeyStatisticsInformation,
+        Collections.emptyList(),
+        stageDefinition,
+        IntSet.of()
+    );
+    verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
     verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
   }
 
@@ -115,8 +126,13 @@ public class WorkerSketchFetcherAutoModeTest
     // Worker count above threshold
     doReturn((int) WorkerSketchFetcher.WORKER_THRESHOLD + 1).when(stageDefinition).getMaxWorkerCount();
 
-    target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
-    verify(target, times(1)).inMemoryFullSketchMerging(any(), any());
+    target.submitFetcherTask(
+        completeKeyStatisticsInformation,
+        Collections.emptyList(),
+        stageDefinition,
+        IntSet.of()
+    );
+    verify(target, times(1)).inMemoryFullSketchMerging(any(), any(), any());
     verify(target, times(0)).sequentialTimeChunkMerging(any(), any(), any());
   }
 
@@ -132,8 +148,13 @@ public class WorkerSketchFetcherAutoModeTest
     // Worker count below threshold
     doReturn(1).when(stageDefinition).getMaxWorkerCount();
 
-    target.submitFetcherTask(completeKeyStatisticsInformation, Collections.emptyList(), stageDefinition);
-    verify(target, times(0)).inMemoryFullSketchMerging(any(), any());
+    target.submitFetcherTask(
+        completeKeyStatisticsInformation,
+        Collections.emptyList(),
+        stageDefinition,
+        IntSet.of()
+    );
+    verify(target, times(0)).inMemoryFullSketchMerging(any(), any(), any());
     verify(target, times(1)).sequentialTimeChunkMerging(any(), any(), any());
   }
 }
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
index 83fb73043b..fc24490036 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java
@@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.ImmutableSortedMap;
 import com.google.common.util.concurrent.Futures;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartition;
 import org.apache.druid.frame.key.ClusterByPartitions;
@@ -88,11 +90,19 @@ public class WorkerSketchFetcherTest
     doReturn(clusterBy).when(stageDefinition).getClusterBy();
     doReturn(25_000).when(stageDefinition).getMaxPartitionCount();
 
-    expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
-    expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(mock(RowKey.class), mock(RowKey.class))));
+    expectedPartitions1 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
+        mock(RowKey.class),
+        mock(RowKey.class)
+    )));
+    expectedPartitions2 = new ClusterByPartitions(ImmutableList.of(new ClusterByPartition(
+        mock(RowKey.class),
+        mock(RowKey.class)
+    )));
 
-    doReturn(Either.value(expectedPartitions1)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
-    doReturn(Either.value(expectedPartitions2)).when(stageDefinition).generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
+    doReturn(Either.value(expectedPartitions1)).when(stageDefinition)
+                                               .generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector1));
+    doReturn(Either.value(expectedPartitions2)).when(stageDefinition)
+                                               .generatePartitionsForShuffle(eq(mergedClusterByStatisticsCollector2));
 
     doReturn(
         mergedClusterByStatisticsCollector1,
@@ -128,10 +138,14 @@ public class WorkerSketchFetcherTest
       return Futures.immediateFuture(snapshot);
     }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), anyInt());
 
+    IntSet workersForStage = new IntAVLTreeSet();
+    workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
+
     CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
         completeKeyStatisticsInformation,
         workerIds,
-        stageDefinition
+        stageDefinition,
+        workersForStage
     );
 
     // Assert that the final result is complete and all other sketches returned have been merged.
@@ -154,7 +168,12 @@ public class WorkerSketchFetcherTest
     // Store snapshots in a queue
     final Queue<ClusterByStatisticsSnapshot> snapshotQueue = new ConcurrentLinkedQueue<>();
 
-    SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(1L, ImmutableSet.of(0, 1, 2), 2L, ImmutableSet.of(0, 1, 4));
+    SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap = ImmutableSortedMap.of(
+        1L,
+        ImmutableSet.of(0, 1, 2),
+        2L,
+        ImmutableSet.of(0, 1, 4)
+    );
     doReturn(timeSegmentVsWorkerMap).when(completeKeyStatisticsInformation).getTimeSegmentVsWorkerMap();
 
     final CyclicBarrier barrier = new CyclicBarrier(3);
@@ -168,10 +187,14 @@ public class WorkerSketchFetcherTest
       return Futures.immediateFuture(snapshot);
     }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyInt(), anyLong());
 
+    IntSet workersForStage = new IntAVLTreeSet();
+    workersForStage.addAll(ImmutableSet.of(0, 1, 2, 3, 4));
+
     CompletableFuture<Either<Long, ClusterByPartitions>> eitherCompletableFuture = target.submitFetcherTask(
         completeKeyStatisticsInformation,
         ImmutableList.of("0", "1", "2", "3", "4"),
-        stageDefinition
+        stageDefinition,
+        workersForStage
     );
 
     // Assert that the final result is complete and all other sketches returned have been merged.
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java
new file mode 100644
index 0000000000..5b9d6e497a
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/WorkerChatHandlerTest.java
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.indexing;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexing.common.TaskReport;
+import org.apache.druid.indexing.common.TaskReportFileWriter;
+import org.apache.druid.indexing.common.TaskToolbox;
+import org.apache.druid.jackson.DefaultObjectMapper;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.msq.counters.CounterSnapshotsTree;
+import org.apache.druid.msq.exec.Worker;
+import org.apache.druid.msq.kernel.StageId;
+import org.apache.druid.msq.kernel.WorkOrder;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.segment.IndexIO;
+import org.apache.druid.segment.IndexMergerV9;
+import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
+import org.apache.druid.server.security.AuthConfig;
+import org.apache.druid.server.security.AuthenticationResult;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+import javax.annotation.Nullable;
+import javax.servlet.http.HttpServletRequest;
+import javax.ws.rs.core.Response;
+import java.io.InputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+public class WorkerChatHandlerTest
+{
+  private static final StageId TEST_STAGE = new StageId("123", 0);
+  @Mock
+  private HttpServletRequest req;
+
+  private TaskToolbox toolbox;
+  private AutoCloseable mocks;
+
+  private final TestWorker worker = new TestWorker();
+
+  @Before
+  public void setUp()
+  {
+    ObjectMapper mapper = new DefaultObjectMapper();
+    IndexIO indexIO = new IndexIO(mapper, () -> 0);
+    IndexMergerV9 indexMerger = new IndexMergerV9(
+        mapper,
+        indexIO,
+        OffHeapMemorySegmentWriteOutMediumFactory.instance()
+    );
+
+    mocks = MockitoAnnotations.openMocks(this);
+    Mockito.when(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT))
+           .thenReturn(new AuthenticationResult("druid", "druid", null, null));
+    TaskToolbox.Builder builder = new TaskToolbox.Builder();
+    toolbox = builder.authorizerMapper(CalciteTests.TEST_AUTHORIZER_MAPPER)
+                     .indexIO(indexIO)
+                     .indexMergerV9(indexMerger)
+                     .taskReportFileWriter(
+                         new TaskReportFileWriter()
+                         {
+                           @Override
+                           public void write(String taskId, Map<String, TaskReport> reports)
+                           {
+
+                           }
+
+                           @Override
+                           public void setObjectMapper(ObjectMapper objectMapper)
+                           {
+
+                           }
+                         }
+                     )
+                     .build();
+  }
+
+  @Test
+  public void testFetchSnapshot()
+  {
+    WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+    Assert.assertEquals(
+        ClusterByStatisticsSnapshot.empty(),
+        chatHandler.httpFetchKeyStatistics(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), req)
+                   .getEntity()
+    );
+  }
+
+  @Test
+  public void testFetchSnapshot404()
+  {
+    WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+    Assert.assertEquals(
+        Response.Status.BAD_REQUEST.getStatusCode(),
+        chatHandler.httpFetchKeyStatistics("123", 2, req)
+                   .getStatus()
+    );
+  }
+
+  @Test
+  public void testFetchSnapshotWithTimeChunk()
+  {
+    WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+    Assert.assertEquals(
+        ClusterByStatisticsSnapshot.empty(),
+        chatHandler.httpFetchKeyStatisticsWithSnapshot(TEST_STAGE.getQueryId(), TEST_STAGE.getStageNumber(), 1, req)
+                   .getEntity()
+    );
+  }
+
+  @Test
+  public void testFetchSnapshotWithTimeChunk404()
+  {
+    WorkerChatHandler chatHandler = new WorkerChatHandler(toolbox, worker);
+    Assert.assertEquals(
+        Response.Status.BAD_REQUEST.getStatusCode(),
+        chatHandler.httpFetchKeyStatisticsWithSnapshot("123", 2, 1, req)
+                   .getStatus()
+    );
+  }
+
+
+  private static class TestWorker implements Worker
+  {
+
+    @Override
+    public String id()
+    {
+      return TEST_STAGE.getQueryId() + "task";
+    }
+
+    @Override
+    public MSQWorkerTask task()
+    {
+      return new MSQWorkerTask("controller", "ds", 1, new HashMap<>());
+    }
+
+    @Override
+    public TaskStatus run()
+    {
+      return null;
+    }
+
+    @Override
+    public void stopGracefully()
+    {
+
+    }
+
+    @Override
+    public void controllerFailed()
+    {
+
+    }
+
+    @Override
+    public void postWorkOrder(WorkOrder workOrder)
+    {
+
+    }
+
+    @Override
+    public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
+    {
+      if (TEST_STAGE.equals(stageId)) {
+        return ClusterByStatisticsSnapshot.empty();
+      } else {
+        throw new ISE("stage not found %s", stageId);
+      }
+    }
+
+    @Override
+    public ClusterByStatisticsSnapshot fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
+    {
+      if (TEST_STAGE.equals(stageId)) {
+        return ClusterByStatisticsSnapshot.empty();
+      } else {
+        throw new ISE("stage not found %s", stageId);
+      }
+    }
+
+    @Override
+    public boolean postResultPartitionBoundaries(
+        ClusterByPartitions stagePartitionBoundaries,
+        String queryId,
+        int stageNumber
+    )
+    {
+      return false;
+    }
+
+    @Nullable
+    @Override
+    public InputStream readChannel(String queryId, int stageNumber, int partitionNumber, long offset)
+    {
+      return null;
+    }
+
+    @Override
+    public CounterSnapshotsTree getCounters()
+    {
+      return null;
+    }
+
+    @Override
+    public void postCleanupStage(StageId stageId)
+    {
+
+    }
+
+    @Override
+    public void postFinish()
+    {
+
+    }
+  }
+
+  @After
+  public void tearDown()
+  {
+    try {
+      mocks.close();
+    }
+    catch (Exception ignored) {
+      // ignore tear down exceptions
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

