cryptoe commented on code in PR #13205:
URL: https://github.com/apache/druid/pull/13205#discussion_r1024229108


##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java:
##########
@@ -595,7 +611,35 @@ public void updateStatus(int stageNumber, int 
workerNumber, Object keyStatistics
             );
           }
 
-          queryKernel.addResultKeyStatisticsForStageAndWorker(stageId, 
workerNumber, keyStatistics);
+          queryKernel.addPartialKeyStatisticsForStageAndWorker(stageId, 
workerNumber, partialKeyStatisticsInformation);
+
+          if 
(queryKernel.getStagePhase(stageId).equals(ControllerStagePhase.MERGING_STATISTICS))
 {
+            List<String> workerTaskIds = workerTaskLauncher.getTaskList();
+            CompleteKeyStatisticsInformation completeKeyStatisticsInformation =
+                queryKernel.getCompleteKeyStatisticsInformation(stageId);
+
+            // Queue the sketch fetching task into the worker sketch fetcher.
+            CompletableFuture<Either<Long, ClusterByPartitions>> 
clusterByPartitionsCompletableFuture =
+                workerSketchFetcher.submitFetcherTask(
+                    completeKeyStatisticsInformation,
+                    workerTaskIds,
+                    stageDef
+                );
+
+            // Add the listener to handle completion.
+            
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, 
throwable) -> {
+              kernelManipulationQueue.add(holder -> {
+                if (throwable != null) {
+                  queryKernel.failStageForReason(stageId, 
UnknownFault.forException(throwable));
+                } else if (clusterByPartitionsEither.isError()) {
+                  queryKernel.failStageForReason(stageId, new 
TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
+                } else {
+                  queryKernel.setClusterByPartitionBoundaries(stageId, 
clusterByPartitionsEither.valueOrThrow());
+                }
+                holder.transitionStageKernel(stageId, 
queryKernel.getStagePhase(stageId));
+              });
+            });
+          }

Review Comment:
   Let's debug-log the else part as well?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java:
##########
@@ -562,6 +568,19 @@ public void postFinish()
     kernelManipulationQueue.add(KernelHolder::setDone);
   }
 
+  @Override
+  public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
+  {
+    return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();

Review Comment:
   What happens if the stage does not have result key stats?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java:
##########
@@ -81,9 +81,11 @@ public String getId()
   // Worker-to-controller messages
 
   /**
-   * Provide a {@link ClusterByStatisticsSnapshot} for shuffling stages.
+   * Accepts a {@link PartialKeyStatisticsInformation} and updates the 
controller key statistics information. If all key
+   * statistics have been gathered, enqueues the task with the {@link 
WorkerSketchFetcher} to generate partiton boundaries.
+   * This is intended to be called by the {@link 
org.apache.druid.msq.indexing.ControllerChatHandler}.
    */
-  void updateStatus(int stageNumber, int workerNumber, Object 
keyStatisticsObject);
+  void updatePartialKeyStatistics(int stageNumber, int workerNumber, Object 
partialKeyStatisticsObject);

Review Comment:
   Let's rename the variable to partialKeyStatisticsInformation.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java:
##########
@@ -595,7 +611,35 @@ public void updateStatus(int stageNumber, int 
workerNumber, Object keyStatistics
             );
           }
 
-          queryKernel.addResultKeyStatisticsForStageAndWorker(stageId, 
workerNumber, keyStatistics);
+          queryKernel.addPartialKeyStatisticsForStageAndWorker(stageId, 
workerNumber, partialKeyStatisticsInformation);
+
+          if 
(queryKernel.getStagePhase(stageId).equals(ControllerStagePhase.MERGING_STATISTICS))
 {
+            List<String> workerTaskIds = workerTaskLauncher.getTaskList();
+            CompleteKeyStatisticsInformation completeKeyStatisticsInformation =
+                queryKernel.getCompleteKeyStatisticsInformation(stageId);
+
+            // Queue the sketch fetching task into the worker sketch fetcher.
+            CompletableFuture<Either<Long, ClusterByPartitions>> 
clusterByPartitionsCompletableFuture =
+                workerSketchFetcher.submitFetcherTask(
+                    completeKeyStatisticsInformation,
+                    workerTaskIds,
+                    stageDef
+                );
+
+            // Add the listener to handle completion.
+            
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, 
throwable) -> {
+              kernelManipulationQueue.add(holder -> {
+                if (throwable != null) {
+                  queryKernel.failStageForReason(stageId, 
UnknownFault.forException(throwable));

Review Comment:
   Also, there is a method called addToKernelManipulationQueue which may be used here.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.ClusterByPartition;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.java.util.common.Either;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.IntStream;
+
+/**
+ * Queues up fetching sketches from workers and progressively generates 
partitions boundaries.
+ */
+public class WorkerSketchFetcher
+{
+  private static final int DEFAULT_THREAD_COUNT = 4;
+  // If the combined size of worker sketches is more than this threshold, 
SEQUENTIAL merging mode is used.
+  private static final long BYTES_THRESHOLD = 1_000_000_000L;
+  // If there are more workers than this threshold, SEQUENTIAL merging mode is 
used.
+  private static final long WORKER_THRESHOLD = 100;
+
+  private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
+  private final int statisticsMaxRetainedBytes;
+  private final WorkerClient workerClient;
+  private final ExecutorService executorService;
+
+  public WorkerSketchFetcher(WorkerClient workerClient, 
ClusterStatisticsMergeMode clusterStatisticsMergeMode, int 
statisticsMaxRetainedBytes)
+  {
+    this.workerClient = workerClient;
+    this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
+    this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
+    this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
+  }
+
+  /**
+   * Submits a request to fetch and generate partitions for the given worker 
statistics and returns a future for it. It
+   * decides based on the statistics if it should fetch sketches one by one or 
together.
+   */
+  public CompletableFuture<Either<Long, ClusterByPartitions>> 
submitFetcherTask(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      List<String> workerTaskIds,
+      StageDefinition stageDefinition
+  )
+  {
+    ClusterBy clusterBy = stageDefinition.getClusterBy();
+
+    switch (clusterStatisticsMergeMode) {
+      case SEQUENTIAL:
+        return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+      case PARALLEL:
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      case AUTO:
+        if (clusterBy.getBucketByCount() == 0) {
+          // If there is no time clustering, there is no scope for sequential 
merge
+          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || 
completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+          return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);

Review Comment:
   Let's debug-log the mode eventually chosen here?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.ClusterByPartition;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.java.util.common.Either;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.IntStream;
+
+/**
+ * Queues up fetching sketches from workers and progressively generates 
partitions boundaries.
+ */
+public class WorkerSketchFetcher
+{
+  private static final int DEFAULT_THREAD_COUNT = 4;
+  // If the combined size of worker sketches is more than this threshold, 
SEQUENTIAL merging mode is used.
+  private static final long BYTES_THRESHOLD = 1_000_000_000L;
+  // If there are more workers than this threshold, SEQUENTIAL merging mode is 
used.
+  private static final long WORKER_THRESHOLD = 100;
+
+  private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
+  private final int statisticsMaxRetainedBytes;
+  private final WorkerClient workerClient;
+  private final ExecutorService executorService;
+
+  public WorkerSketchFetcher(WorkerClient workerClient, 
ClusterStatisticsMergeMode clusterStatisticsMergeMode, int 
statisticsMaxRetainedBytes)
+  {
+    this.workerClient = workerClient;
+    this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
+    this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
+    this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
+  }
+
+  /**
+   * Submits a request to fetch and generate partitions for the given worker 
statistics and returns a future for it. It
+   * decides based on the statistics if it should fetch sketches one by one or 
together.
+   */
+  public CompletableFuture<Either<Long, ClusterByPartitions>> 
submitFetcherTask(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      List<String> workerTaskIds,
+      StageDefinition stageDefinition
+  )
+  {
+    ClusterBy clusterBy = stageDefinition.getClusterBy();
+
+    switch (clusterStatisticsMergeMode) {
+      case SEQUENTIAL:
+        return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+      case PARALLEL:
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      case AUTO:
+        if (clusterBy.getBucketByCount() == 0) {
+          // If there is no time clustering, there is no scope for sequential 
merge
+          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || 
completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+          return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+        }
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      default:
+        throw new IllegalStateException("No fetching strategy found for mode: 
" + clusterStatisticsMergeMode);
+    }
+  }
+
+  /**
+   * Fetches the full {@link ClusterByStatisticsCollector} from all workers 
and generates partition boundaries from them.
+   * This is faster than fetching them timechunk by timechunk but the 
collector will be downsampled till it can fit
+   * on the controller, resulting in less accurate partition boundries.
+   */
+  private CompletableFuture<Either<Long, ClusterByPartitions>> 
inMemoryFullSketchMerging(
+      StageDefinition stageDefinition,
+      List<String> workerTaskIds
+  )
+  {
+    CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new 
CompletableFuture<>();
+
+    // Create a new key statistics collector to merge worker sketches into
+    final ClusterByStatisticsCollector mergedStatisticsCollector =
+        
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
+    final int workerCount = workerTaskIds.size();
+    // Guarded by synchronized mergedStatisticsCollector
+    final Set<Integer> finishedWorkers = new HashSet<>();
+
+    // Submit a task for each worker to fetch statistics
+    IntStream.range(0, workerCount).forEach(workerNo -> {
+      executorService.submit(() -> {
+        ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
+            workerClient.fetchClusterByStatisticsSnapshot(
+                workerTaskIds.get(workerNo),
+                stageDefinition.getId().getQueryId(),
+                stageDefinition.getStageNumber()
+            );
+        partitionFuture.whenComplete((result, exception) -> 
snapshotFuture.cancel(true));

Review Comment:
   I think the logic here should be: if the partition future completes
exceptionally, cancel the corresponding work — rather than cancelling the work
every time the partitionFuture gets completed. Wdyt?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/CompleteKeyStatisticsInformation.java:
##########
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.statistics;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.SortedMap;
+
+/**
+ * Class maintained by the controller to merge {@link 
PartialKeyStatisticsInformation} sent by the worker.
+ */
+public class CompleteKeyStatisticsInformation
+{
+  private final SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap;
+
+  private boolean hasMultipleValues;
+
+  private double bytesRetained;
+
+  public CompleteKeyStatisticsInformation(
+      final SortedMap<Long, Set<Integer>> timeChunks,
+      boolean hasMultipleValues,
+      double bytesRetained
+  )
+  {
+    this.timeSegmentVsWorkerMap = timeChunks;
+    this.hasMultipleValues = hasMultipleValues;
+    this.bytesRetained = bytesRetained;
+  }
+
+  public void mergePartialInformation(int workerNumber, 
PartialKeyStatisticsInformation partialKeyStatisticsInformation)

Review Comment:
   I think we should add Javadoc for this, especially for the sorted-key aspect of
SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -259,16 +257,21 @@ ControllerStagePhase addResultKeyStatisticsForWorker(
     }
 
     try {
-      if (workersWithResultKeyStatistics.add(workerNumber)) {
-        resultKeyStatisticsCollector.addAll(snapshot);
+      if (workersWithReportedKeyStatistics.add(workerNumber)) {
 
-        if (workersWithResultKeyStatistics.size() == workerCount) {
-          generateResultPartitionsAndBoundaries();
+        if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) {

Review Comment:
   Is there an associated UT for this?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -234,23 +224,31 @@ WorkerInputs getWorkerInputs()
     return workerInputs;
   }
 
+  /**
+   * Returns the merged key statistics.
+   */
+  @Nullable
+  public CompleteKeyStatisticsInformation getCompleteKeyStatisticsInformation()
+  {
+    return completeKeyStatisticsInformation;
+  }
+
   /**
    * Adds result key statistics for a particular worker number. If statistics 
have already been added for this worker,
    * then this call ignores the new ones and does nothing.
    *
    * @param workerNumber the worker
-   * @param snapshot     worker statistics
+   * @param partialKeyStatisticsInformation partial key statistics
    */
-  ControllerStagePhase addResultKeyStatisticsForWorker(
+  ControllerStagePhase addPartialKeyStatisticsForWorker(
       final int workerNumber,
-      final ClusterByStatisticsSnapshot snapshot
+      final PartialKeyStatisticsInformation partialKeyStatisticsInformation
   )
   {
     if (phase != ControllerStagePhase.READING_INPUT) {
       throw new ISE("Cannot add result key statistics from stage [%s]", phase);
     }
-
-    if (resultKeyStatisticsCollector == null) {
+    if (!stageDef.doesShuffle() || completeKeyStatisticsInformation == null) {

Review Comment:
   mustGatherResultKeyStatistics should be checked, as we sometimes add an
empty shuffle spec.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -259,16 +257,21 @@ ControllerStagePhase addResultKeyStatisticsForWorker(
     }
 
     try {
-      if (workersWithResultKeyStatistics.add(workerNumber)) {
-        resultKeyStatisticsCollector.addAll(snapshot);
+      if (workersWithReportedKeyStatistics.add(workerNumber)) {
 
-        if (workersWithResultKeyStatistics.size() == workerCount) {
-          generateResultPartitionsAndBoundaries();
+        if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) {
+          // Time should not contain null value
+          failForReason(InsertTimeNullFault.instance());
+          return getPhase();
+        }
+
+        completeKeyStatisticsInformation.mergePartialInformation(workerNumber, 
partialKeyStatisticsInformation);
+
+        if (workersWithReportedKeyStatistics.size() == workerCount) {
+          // All workers have sent the report.
+          // Transition to MERGING_STATISTICS state to queue fetch clustering 
statistics from workers.
+          transitionTo(ControllerStagePhase.MERGING_STATISTICS);

Review Comment:
   We should mention comments similar to those on lines 271-272 in the Javadoc of
ControllerStagePhase.MERGING_STATISTICS.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.ClusterByPartition;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.java.util.common.Either;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.IntStream;
+
+/**
+ * Queues up fetching sketches from workers and progressively generates 
partitions boundaries.
+ */
+public class WorkerSketchFetcher
+{
+  private static final int DEFAULT_THREAD_COUNT = 4;
+  // If the combined size of worker sketches is more than this threshold, 
SEQUENTIAL merging mode is used.
+  private static final long BYTES_THRESHOLD = 1_000_000_000L;
+  // If there are more workers than this threshold, SEQUENTIAL merging mode is 
used.
+  private static final long WORKER_THRESHOLD = 100;
+
+  private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
+  private final int statisticsMaxRetainedBytes;
+  private final WorkerClient workerClient;
+  private final ExecutorService executorService;
+
+  public WorkerSketchFetcher(WorkerClient workerClient, 
ClusterStatisticsMergeMode clusterStatisticsMergeMode, int 
statisticsMaxRetainedBytes)
+  {
+    this.workerClient = workerClient;
+    this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
+    this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
+    this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
+  }
+
+  /**
+   * Submits a request to fetch and generate partitions for the given worker 
statistics and returns a future for it. It
+   * decides based on the statistics if it should fetch sketches one by one or 
together.
+   */
+  public CompletableFuture<Either<Long, ClusterByPartitions>> 
submitFetcherTask(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      List<String> workerTaskIds,
+      StageDefinition stageDefinition
+  )
+  {
+    ClusterBy clusterBy = stageDefinition.getClusterBy();
+
+    switch (clusterStatisticsMergeMode) {
+      case SEQUENTIAL:
+        return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+      case PARALLEL:
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      case AUTO:
+        if (clusterBy.getBucketByCount() == 0) {
+          // If there is no time clustering, there is no scope for sequential 
merge
+          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || 
completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+          return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+        }
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      default:
+        throw new IllegalStateException("No fetching strategy found for mode: 
" + clusterStatisticsMergeMode);
+    }
+  }
+
+  /**
+   * Fetches the full {@link ClusterByStatisticsCollector} from all workers 
and generates partition boundaries from them.
+   * This is faster than fetching them timechunk by timechunk but the 
collector will be downsampled till it can fit
+   * on the controller, resulting in less accurate partition boundries.
+   */
+  private CompletableFuture<Either<Long, ClusterByPartitions>> 
inMemoryFullSketchMerging(
+      StageDefinition stageDefinition,
+      List<String> workerTaskIds
+  )
+  {
+    CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new 
CompletableFuture<>();
+
+    // Create a new key statistics collector to merge worker sketches into
+    final ClusterByStatisticsCollector mergedStatisticsCollector =
+        
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
+    final int workerCount = workerTaskIds.size();
+    // Guarded by synchronized mergedStatisticsCollector
+    final Set<Integer> finishedWorkers = new HashSet<>();
+
+    // Submit a task for each worker to fetch statistics
+    IntStream.range(0, workerCount).forEach(workerNo -> {
+      executorService.submit(() -> {
+        ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
+            workerClient.fetchClusterByStatisticsSnapshot(
+                workerTaskIds.get(workerNo),
+                stageDefinition.getId().getQueryId(),
+                stageDefinition.getStageNumber()
+            );
+        partitionFuture.whenComplete((result, exception) -> 
snapshotFuture.cancel(true));
+
+        try {
+          ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = 
snapshotFuture.get();
+          synchronized (mergedStatisticsCollector) {
+            mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot);
+            finishedWorkers.add(workerNo);
+
+            if (finishedWorkers.size() == workerCount) {
+              
partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
+            }
+          }
+        }
+        catch (Exception e) {
+          synchronized (mergedStatisticsCollector) {
+            partitionFuture.completeExceptionally(e);
+          }
+        }
+      });
+    });
+    return partitionFuture;
+  }
+
+  /**
+   * Fetches cluster statistics from all workers and generates partition 
boundaries from them one time chunk at a time.
+   * This takes longer due to the overhead of fetching sketches, however, this 
prevents any loss in accuracy from
+   * downsampling on the controller.
+   */
+  private CompletableFuture<Either<Long, ClusterByPartitions>> 
sequentialTimeChunkMerging(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      StageDefinition stageDefinition,
+      List<String> workerTaskIds
+  )
+  {
+    SequentialFetchStage sequentialFetchStage = new SequentialFetchStage(
+        stageDefinition,
+        workerTaskIds,
+        
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
+    );
+    sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
+    return sequentialFetchStage.getPartitionFuture();
+  }
+
+  private class SequentialFetchStage
+  {
+    private final StageDefinition stageDefinition;
+    private final List<String> workerTaskIds;
+    private final Iterator<Map.Entry<Long, Set<Integer>>> 
timeSegmentVsWorkerIdIterator;
+    private final CompletableFuture<Either<Long, ClusterByPartitions>> 
partitionFuture;
+    // Final sorted list of partition boundaries. This is appended to after 
statistics for each time chunk are gathered.
+    private final List<ClusterByPartition> finalPartitionBoundries;
+
+    public SequentialFetchStage(
+        StageDefinition stageDefinition,
+        List<String> workerTaskIds,
+        Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator
+    )
+    {
+      this.finalPartitionBoundries = new ArrayList<>();
+      this.stageDefinition = stageDefinition;
+      this.workerTaskIds = workerTaskIds;
+      this.timeSegmentVsWorkerIdIterator = timeSegmentVsWorkerIdIterator;
+      this.partitionFuture = new CompletableFuture<>();
+    }
+
+    public void submitFetchingTasksForNextTimeChunk()

Review Comment:
   Let's add Javadoc to this method.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java:
##########
@@ -562,6 +568,19 @@ public void postFinish()
     kernelManipulationQueue.add(KernelHolder::setDone);
   }
 
+  @Override
+  public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
+  {
+    return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();

Review Comment:
   Or the result key stats are not yet populated.



##########
integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/msq/ITMultiStageQuery.java:
##########
@@ -122,4 +126,144 @@ public void testMsqIngestionAndQuerying() throws Exception
 
     msqHelper.testQueriesFromFile(QUERY_FILE, datasource);
   }
+
+  @Test
+  public void testMsqIngestionParallelMerging() throws Exception
+  {
+    String datasource = "dst";
+
+    // Clear up the datasource from the previous runs
+    coordinatorClient.unloadSegmentsForDataSource(datasource);
+
+    String queryLocal =
+        StringUtils.format(
+            "INSERT INTO %s\n"
+            + "SELECT\n"
+            + "  TIME_PARSE(\"timestamp\") AS __time,\n"
+            + "  isRobot,\n"
+            + "  diffUrl,\n"
+            + "  added,\n"
+            + "  countryIsoCode,\n"
+            + "  regionName,\n"
+            + "  channel,\n"
+            + "  flags,\n"
+            + "  delta,\n"
+            + "  isUnpatrolled,\n"
+            + "  isNew,\n"
+            + "  deltaBucket,\n"
+            + "  isMinor,\n"
+            + "  isAnonymous,\n"
+            + "  deleted,\n"
+            + "  cityName,\n"
+            + "  metroCode,\n"
+            + "  namespace,\n"
+            + "  comment,\n"
+            + "  page,\n"
+            + "  commentLength,\n"
+            + "  countryName,\n"
+            + "  user,\n"
+            + "  regionIsoCode\n"
+            + "FROM TABLE(\n"
+            + "  EXTERN(\n"
+            + "    
'{\"type\":\"local\",\"files\":[\"/resources/data/batch_index/json/wikipedia_index_data1.json\"]}',\n"
+            + "    '{\"type\":\"json\"}',\n"
+            + "    
'[{\"type\":\"string\",\"name\":\"timestamp\"},{\"type\":\"string\",\"name\":\"isRobot\"},{\"type\":\"string\",\"name\":\"diffUrl\"},{\"type\":\"long\",\"name\":\"added\"},{\"type\":\"string\",\"name\":\"countryIsoCode\"},{\"type\":\"string\",\"name\":\"regionName\"},{\"type\":\"string\",\"name\":\"channel\"},{\"type\":\"string\",\"name\":\"flags\"},{\"type\":\"long\",\"name\":\"delta\"},{\"type\":\"string\",\"name\":\"isUnpatrolled\"},{\"type\":\"string\",\"name\":\"isNew\"},{\"type\":\"double\",\"name\":\"deltaBucket\"},{\"type\":\"string\",\"name\":\"isMinor\"},{\"type\":\"string\",\"name\":\"isAnonymous\"},{\"type\":\"long\",\"name\":\"deleted\"},{\"type\":\"string\",\"name\":\"cityName\"},{\"type\":\"long\",\"name\":\"metroCode\"},{\"type\":\"string\",\"name\":\"namespace\"},{\"type\":\"string\",\"name\":\"comment\"},{\"type\":\"string\",\"name\":\"page\"},{\"type\":\"long\",\"name\":\"commentLength\"},{\"type\":\"string\",\"name\":\"countryName\"},{\"type\":
 
\"string\",\"name\":\"user\"},{\"type\":\"string\",\"name\":\"regionIsoCode\"}]'\n"
+            + "  )\n"
+            + ")\n"
+            + "PARTITIONED BY DAY\n"
+            + "CLUSTERED BY \"__time\"",
+            datasource
+        );
+
+    ImmutableMap<String, Object> context = ImmutableMap.of(
+        MultiStageQueryContext.CTX_CLUSTER_STATISTICS_MERGE_MODE,
+        ClusterStatisticsMergeMode.PARALLEL

Review Comment:
   Sure, as long as we test that the correct mode is engaged.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java:
##########
@@ -562,6 +568,19 @@ public void postFinish()
     kernelManipulationQueue.add(KernelHolder::setDone);
   }
 
+  @Override
+  public ClusterByStatisticsSnapshot fetchStatisticsSnapshot(StageId stageId)
+  {
+    return stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();
+  }
+
+  @Override
+  public ClusterByStatisticsSnapshot 
fetchStatisticsSnapshotForTimeChunk(StageId stageId, long timeChunk)
+  {
+    ClusterByStatisticsSnapshot snapshot = 
stageKernelMap.get(stageId).getResultKeyStatisticsSnapshot();

Review Comment:
   Similar Q here.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java:
##########
@@ -595,7 +611,35 @@ public void updateStatus(int stageNumber, int 
workerNumber, Object keyStatistics
             );
           }
 
-          queryKernel.addResultKeyStatisticsForStageAndWorker(stageId, 
workerNumber, keyStatistics);
+          queryKernel.addPartialKeyStatisticsForStageAndWorker(stageId, 
workerNumber, partialKeyStatisticsInformation);
+
+          if 
(queryKernel.getStagePhase(stageId).equals(ControllerStagePhase.MERGING_STATISTICS))
 {
+            List<String> workerTaskIds = workerTaskLauncher.getTaskList();
+            CompleteKeyStatisticsInformation completeKeyStatisticsInformation =
+                queryKernel.getCompleteKeyStatisticsInformation(stageId);
+
+            // Queue the sketch fetching task into the worker sketch fetcher.
+            CompletableFuture<Either<Long, ClusterByPartitions>> 
clusterByPartitionsCompletableFuture =
+                workerSketchFetcher.submitFetcherTask(
+                    completeKeyStatisticsInformation,
+                    workerTaskIds,
+                    stageDef
+                );
+
+            // Add the listener to handle completion.
+            
clusterByPartitionsCompletableFuture.whenComplete((clusterByPartitionsEither, 
throwable) -> {
+              kernelManipulationQueue.add(holder -> {
+                if (throwable != null) {
+                  queryKernel.failStageForReason(stageId, 
UnknownFault.forException(throwable));

Review Comment:
   This should be the holder object. I do not know if it matters, but for the sake of
the pattern we are using.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/ControllerChatHandler.java:
##########
@@ -58,24 +59,25 @@ public ControllerChatHandler(TaskToolbox toolbox, 
Controller controller)
   }
 
   /**
-   * Used by subtasks to post {@link ClusterByStatisticsSnapshot} for 
shuffling stages.
+   * Used by subtasks to post {@link PartialKeyStatisticsInformation} for 
shuffling stages.
    *
-   * See {@link ControllerClient#postKeyStatistics} for the client-side code 
that calls this API.
+   * See {@link ControllerClient#postPartialKeyStatistics(StageId, int, 
PartialKeyStatisticsInformation)}
+   * for the client-side code that calls this API.
    */
   @POST
-  @Path("/keyStatistics/{queryId}/{stageNumber}/{workerNumber}")
+  @Path("/partialKeyStatistics/{queryId}/{stageNumber}/{workerNumber}")

Review Comment:
   nit: I would change this to: partialKeyStatisticsInformation



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.ClusterByPartition;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.java.util.common.Either;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.IntStream;
+
+/**
+ * Queues up fetching sketches from workers and progressively generates 
partitions boundaries.
+ */
+public class WorkerSketchFetcher
+{
+  private static final int DEFAULT_THREAD_COUNT = 4;
+  // If the combined size of worker sketches is more than this threshold, 
SEQUENTIAL merging mode is used.
+  private static final long BYTES_THRESHOLD = 1_000_000_000L;
+  // If there are more workers than this threshold, SEQUENTIAL merging mode is 
used.
+  private static final long WORKER_THRESHOLD = 100;
+
+  private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
+  private final int statisticsMaxRetainedBytes;
+  private final WorkerClient workerClient;
+  private final ExecutorService executorService;
+
+  public WorkerSketchFetcher(WorkerClient workerClient, 
ClusterStatisticsMergeMode clusterStatisticsMergeMode, int 
statisticsMaxRetainedBytes)
+  {
+    this.workerClient = workerClient;
+    this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
+    this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
+    this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
+  }
+
+  /**
+   * Submits a request to fetch and generate partitions for the given worker 
statistics and returns a future for it. It
+   * decides based on the statistics if it should fetch sketches one by one or 
together.
+   */
+  public CompletableFuture<Either<Long, ClusterByPartitions>> 
submitFetcherTask(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      List<String> workerTaskIds,
+      StageDefinition stageDefinition
+  )
+  {
+    ClusterBy clusterBy = stageDefinition.getClusterBy();
+
+    switch (clusterStatisticsMergeMode) {
+      case SEQUENTIAL:
+        return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+      case PARALLEL:
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      case AUTO:
+        if (clusterBy.getBucketByCount() == 0) {
+          // If there is no time clustering, there is no scope for sequential 
merge
+          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || 
completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+          return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+        }
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      default:
+        throw new IllegalStateException("No fetching strategy found for mode: 
" + clusterStatisticsMergeMode);
+    }
+  }
+
+  /**
+   * Fetches the full {@link ClusterByStatisticsCollector} from all workers 
and generates partition boundaries from them.
+   * This is faster than fetching them timechunk by timechunk but the 
collector will be downsampled till it can fit
+   * on the controller, resulting in less accurate partition boundries.
+   */
+  private CompletableFuture<Either<Long, ClusterByPartitions>> 
inMemoryFullSketchMerging(
+      StageDefinition stageDefinition,
+      List<String> workerTaskIds
+  )
+  {
+    CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new 
CompletableFuture<>();
+
+    // Create a new key statistics collector to merge worker sketches into
+    final ClusterByStatisticsCollector mergedStatisticsCollector =
+        
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
+    final int workerCount = workerTaskIds.size();
+    // Guarded by synchronized mergedStatisticsCollector
+    final Set<Integer> finishedWorkers = new HashSet<>();
+
+    // Submit a task for each worker to fetch statistics
+    IntStream.range(0, workerCount).forEach(workerNo -> {
+      executorService.submit(() -> {
+        ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
+            workerClient.fetchClusterByStatisticsSnapshot(
+                workerTaskIds.get(workerNo),
+                stageDefinition.getId().getQueryId(),
+                stageDefinition.getStageNumber()
+            );
+        partitionFuture.whenComplete((result, exception) -> 
snapshotFuture.cancel(true));
+
+        try {
+          ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = 
snapshotFuture.get();
+          synchronized (mergedStatisticsCollector) {
+            mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot);
+            finishedWorkers.add(workerNo);
+
+            if (finishedWorkers.size() == workerCount) {
+              
partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
+            }
+          }
+        }
+        catch (Exception e) {
+          synchronized (mergedStatisticsCollector) {

Review Comment:
   I could not reason about this lock. 
   



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java:
##########
@@ -565,10 +579,12 @@ private QueryDefinition initializeQueryDefAndState(final 
Closer closer)
   }
 
   /**
-   * Provide a {@link ClusterByStatisticsSnapshot} for shuffling stages.
+   * Accepts a {@link PartialKeyStatisticsInformation} and updates the 
controller key statistics information. If all key
+   * statistics information has been gathered, enqueues the task with the 
{@link WorkerSketchFetcher} to generate
+   * partiton boundaries. This is intended to be called by the {@link 
org.apache.druid.msq.indexing.ControllerChatHandler}.
    */
   @Override
-  public void updateStatus(int stageNumber, int workerNumber, Object 
keyStatisticsObject)
+  public void updatePartialKeyStatistics(int stageNumber, int workerNumber, 
Object partialKeyStatisticsObject)

Review Comment:
   nit: partialKeyStatisticsInfObject?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/CompleteKeyStatisticsInformation.java:
##########
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.statistics;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.SortedMap;
+
+/**
+ * Class maintained by the controller to merge {@link 
PartialKeyStatisticsInformation} sent by the worker.
+ */
+public class CompleteKeyStatisticsInformation
+{
+  private final SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap;
+
+  private boolean hasMultipleValues;
+
+  private double bytesRetained;
+
+  public CompleteKeyStatisticsInformation(
+      final SortedMap<Long, Set<Integer>> timeChunks,
+      boolean hasMultipleValues,
+      double bytesRetained
+  )
+  {
+    this.timeSegmentVsWorkerMap = timeChunks;
+    this.hasMultipleValues = hasMultipleValues;
+    this.bytesRetained = bytesRetained;
+  }
+
+  public void mergePartialInformation(int workerNumber, 
PartialKeyStatisticsInformation partialKeyStatisticsInformation)
+  {
+    for (Long timeSegment : partialKeyStatisticsInformation.getTimeSegments()) 
{
+      this.timeSegmentVsWorkerMap
+          .computeIfAbsent(timeSegment, key -> new HashSet<>())
+          .add(workerNumber);
+    }
+    this.hasMultipleValues = this.hasMultipleValues || 
partialKeyStatisticsInformation.isHasMultipleValues();
+    this.bytesRetained += bytesRetained;
+  }
+
+  public SortedMap<Long, Set<Integer>> getTimeSegmentVsWorkerMap()
+  {
+    return timeSegmentVsWorkerMap;
+  }
+
+  public boolean isHasMultipleValues()

Review Comment:
   nit: this seems weird. Can we do hasMultipleValues()?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -397,15 +418,11 @@ private void generateResultPartitionsAndBoundaries()
    *
    * @param fault reason why this stage has failed
    */
-  private void failForReason(final MSQFault fault)
+  void failForReason(final MSQFault fault)
   {
     transitionTo(ControllerStagePhase.FAILED);
 
     this.failureReason = fault;
-
-    if (resultKeyStatisticsCollector != null) {

Review Comment:
   Is there a reason we are reverting this change?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/CompleteKeyStatisticsInformation.java:
##########
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.statistics;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.SortedMap;
+
+/**
+ * Class maintained by the controller to merge {@link 
PartialKeyStatisticsInformation} sent by the worker.
+ */
+public class CompleteKeyStatisticsInformation
+{
+  private final SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap;
+
+  private boolean hasMultipleValues;

Review Comment:
   Nit: we can remove the "has" from this variable name and then add it to the
function name.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -280,6 +283,25 @@ ControllerStagePhase addResultKeyStatisticsForWorker(
     return getPhase();
   }
 
+  /**
+   * Sets the {@link #resultPartitions} and {@link #resultPartitionBoundaries} 
and transitions the phase to POST_READING.
+   */
+  void setClusterByPartitionBoundaries(ClusterByPartitions clusterByPartitions)
+  {
+    if (resultPartitions != null) {

Review Comment:
   Let's do a state check that the current state should be MergingStats?



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -259,16 +257,21 @@ ControllerStagePhase addResultKeyStatisticsForWorker(
     }
 
     try {
-      if (workersWithResultKeyStatistics.add(workerNumber)) {
-        resultKeyStatisticsCollector.addAll(snapshot);
+      if (workersWithReportedKeyStatistics.add(workerNumber)) {
 
-        if (workersWithResultKeyStatistics.size() == workerCount) {
-          generateResultPartitionsAndBoundaries();
+        if (partialKeyStatisticsInformation.getTimeSegments().contains(null)) {
+          // Time should not contain null value
+          failForReason(InsertTimeNullFault.instance());
+          return getPhase();
+        }
+
+        completeKeyStatisticsInformation.mergePartialInformation(workerNumber, 
partialKeyStatisticsInformation);
+
+        if (workersWithReportedKeyStatistics.size() == workerCount) {
+          // All workers have sent the report.

Review Comment:
   ```suggestion
             // All workers have sent the partial key statistics information.
   ```



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java:
##########
@@ -280,6 +283,25 @@ ControllerStagePhase addResultKeyStatisticsForWorker(
     return getPhase();
   }
 
+  /**
+   * Sets the {@link #resultPartitions} and {@link #resultPartitionBoundaries} 
and transitions the phase to POST_READING.
+   */
+  void setClusterByPartitionBoundaries(ClusterByPartitions clusterByPartitions)
+  {
+    if (resultPartitions != null) {

Review Comment:
   And also if this stage.needsKeyStats() 



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/ControllerChatHandler.java:
##########
@@ -58,24 +59,25 @@ public ControllerChatHandler(TaskToolbox toolbox, 
Controller controller)
   }
 
   /**
-   * Used by subtasks to post {@link ClusterByStatisticsSnapshot} for 
shuffling stages.
+   * Used by subtasks to post {@link PartialKeyStatisticsInformation} for 
shuffling stages.
    *
-   * See {@link ControllerClient#postKeyStatistics} for the client-side code 
that calls this API.
+   * See {@link ControllerClient#postPartialKeyStatistics(StageId, int, 
PartialKeyStatisticsInformation)}
+   * for the client-side code that calls this API.
    */
   @POST
-  @Path("/keyStatistics/{queryId}/{stageNumber}/{workerNumber}")
+  @Path("/partialKeyStatistics/{queryId}/{stageNumber}/{workerNumber}")

Review Comment:
   Also, we would break backward compatibility here so it's worth mentioning 
that in the release notes.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.ClusterByPartition;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.java.util.common.Either;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.stream.IntStream;
+
+/**
+ * Queues up fetching sketches from workers and progressively generates 
partitions boundaries.
+ */
+public class WorkerSketchFetcher
+{
+  private static final int DEFAULT_THREAD_COUNT = 4;
+  // If the combined size of worker sketches is more than this threshold, 
SEQUENTIAL merging mode is used.
+  private static final long BYTES_THRESHOLD = 1_000_000_000L;
+  // If there are more workers than this threshold, SEQUENTIAL merging mode is 
used.
+  private static final long WORKER_THRESHOLD = 100;
+
+  private final ClusterStatisticsMergeMode clusterStatisticsMergeMode;
+  private final int statisticsMaxRetainedBytes;
+  private final WorkerClient workerClient;
+  private final ExecutorService executorService;
+
+  public WorkerSketchFetcher(WorkerClient workerClient, 
ClusterStatisticsMergeMode clusterStatisticsMergeMode, int 
statisticsMaxRetainedBytes)
+  {
+    this.workerClient = workerClient;
+    this.clusterStatisticsMergeMode = clusterStatisticsMergeMode;
+    this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
+    this.statisticsMaxRetainedBytes = statisticsMaxRetainedBytes;
+  }
+
+  /**
+   * Submits a request to fetch and generate partitions for the given worker 
statistics and returns a future for it. It
+   * decides based on the statistics if it should fetch sketches one by one or 
together.
+   */
+  public CompletableFuture<Either<Long, ClusterByPartitions>> 
submitFetcherTask(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      List<String> workerTaskIds,
+      StageDefinition stageDefinition
+  )
+  {
+    ClusterBy clusterBy = stageDefinition.getClusterBy();
+
+    switch (clusterStatisticsMergeMode) {
+      case SEQUENTIAL:
+        return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+      case PARALLEL:
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      case AUTO:
+        if (clusterBy.getBucketByCount() == 0) {
+          // If there is no time clustering, there is no scope for sequential 
merge
+          return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+        } else if (stageDefinition.getMaxWorkerCount() > WORKER_THRESHOLD || 
completeKeyStatisticsInformation.getBytesRetained() > BYTES_THRESHOLD) {
+          return sequentialTimeChunkMerging(completeKeyStatisticsInformation, 
stageDefinition, workerTaskIds);
+        }
+        return inMemoryFullSketchMerging(stageDefinition, workerTaskIds);
+      default:
+        throw new IllegalStateException("No fetching strategy found for mode: 
" + clusterStatisticsMergeMode);
+    }
+  }
+
+  /**
+   * Fetches the full {@link ClusterByStatisticsCollector} from all workers 
and generates partition boundaries from them.
+   * This is faster than fetching them timechunk by timechunk but the 
collector will be downsampled till it can fit
+   * on the controller, resulting in less accurate partition boundries.
+   */
+  private CompletableFuture<Either<Long, ClusterByPartitions>> 
inMemoryFullSketchMerging(
+      StageDefinition stageDefinition,
+      List<String> workerTaskIds
+  )
+  {
+    CompletableFuture<Either<Long, ClusterByPartitions>> partitionFuture = new 
CompletableFuture<>();
+
+    // Create a new key statistics collector to merge worker sketches into
+    final ClusterByStatisticsCollector mergedStatisticsCollector =
+        
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
+    final int workerCount = workerTaskIds.size();
+    // Guarded by synchronized mergedStatisticsCollector
+    final Set<Integer> finishedWorkers = new HashSet<>();
+
+    // Submit a task for each worker to fetch statistics
+    IntStream.range(0, workerCount).forEach(workerNo -> {
+      executorService.submit(() -> {
+        ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
+            workerClient.fetchClusterByStatisticsSnapshot(
+                workerTaskIds.get(workerNo),
+                stageDefinition.getId().getQueryId(),
+                stageDefinition.getStageNumber()
+            );
+        partitionFuture.whenComplete((result, exception) -> 
snapshotFuture.cancel(true));
+
+        try {
+          ClusterByStatisticsSnapshot clusterByStatisticsSnapshot = 
snapshotFuture.get();
+          synchronized (mergedStatisticsCollector) {
+            mergedStatisticsCollector.addAll(clusterByStatisticsSnapshot);
+            finishedWorkers.add(workerNo);
+
+            if (finishedWorkers.size() == workerCount) {
+              
partitionFuture.complete(stageDefinition.generatePartitionsForShuffle(mergedStatisticsCollector));
+            }
+          }
+        }
+        catch (Exception e) {
+          synchronized (mergedStatisticsCollector) {
+            partitionFuture.completeExceptionally(e);
+          }
+        }
+      });
+    });
+    return partitionFuture;
+  }
+
+  /**
+   * Fetches cluster statistics from all workers and generates partition 
boundaries from them one time chunk at a time.
+   * This takes longer due to the overhead of fetching sketches, however, this 
prevents any loss in accuracy from
+   * downsampling on the controller.
+   */
+  private CompletableFuture<Either<Long, ClusterByPartitions>> 
sequentialTimeChunkMerging(
+      CompleteKeyStatisticsInformation completeKeyStatisticsInformation,
+      StageDefinition stageDefinition,
+      List<String> workerTaskIds
+  )
+  {
+    SequentialFetchStage sequentialFetchStage = new SequentialFetchStage(
+        stageDefinition,
+        workerTaskIds,
+        
completeKeyStatisticsInformation.getTimeSegmentVsWorkerMap().entrySet().iterator()
+    );
+    sequentialFetchStage.submitFetchingTasksForNextTimeChunk();
+    return sequentialFetchStage.getPartitionFuture();
+  }
+
+  private class SequentialFetchStage
+  {
+    private final StageDefinition stageDefinition;
+    private final List<String> workerTaskIds;
+    private final Iterator<Map.Entry<Long, Set<Integer>>> 
timeSegmentVsWorkerIdIterator;
+    private final CompletableFuture<Either<Long, ClusterByPartitions>> 
partitionFuture;
+    // Final sorted list of partition boundaries. This is appended to after 
statistics for each time chunk are gathered.
+    private final List<ClusterByPartition> finalPartitionBoundries;
+
+    public SequentialFetchStage(
+        StageDefinition stageDefinition,
+        List<String> workerTaskIds,
+        Iterator<Map.Entry<Long, Set<Integer>>> timeSegmentVsWorkerIdIterator
+    )
+    {
+      this.finalPartitionBoundries = new ArrayList<>();
+      this.stageDefinition = stageDefinition;
+      this.workerTaskIds = workerTaskIds;
+      this.timeSegmentVsWorkerIdIterator = timeSegmentVsWorkerIdIterator;
+      this.partitionFuture = new CompletableFuture<>();
+    }
+
+    public void submitFetchingTasksForNextTimeChunk()
+    {
+      if (!timeSegmentVsWorkerIdIterator.hasNext()) {
+        partitionFuture.complete(Either.value(new 
ClusterByPartitions(finalPartitionBoundries)));
+      } else {
+        Map.Entry<Long, Set<Integer>> entry = 
timeSegmentVsWorkerIdIterator.next();
+        // Time chunk for which partition boundries are going to be generated 
for
+        Long timeChunk = entry.getKey();
+        Set<Integer> workerIdsWithTimeChunk = entry.getValue();
+        // Create a new key statistics collector to merge worker sketches into
+        ClusterByStatisticsCollector mergedStatisticsCollector =
+            
stageDefinition.createResultKeyStatisticsCollector(statisticsMaxRetainedBytes);
+        // Guarded by synchronized mergedStatisticsCollector
+        Set<Integer> finishedWorkers = new HashSet<>();
+
+        // Submits a task for every worker which has a certain time chunk
+        for (int workerNo : workerIdsWithTimeChunk) {
+          executorService.submit(() -> {
+            ListenableFuture<ClusterByStatisticsSnapshot> snapshotFuture =
+                workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
+                    workerTaskIds.get(workerNo),
+                    stageDefinition.getId().getQueryId(),
+                    stageDefinition.getStageNumber(),
+                    timeChunk
+                );
+            partitionFuture.whenComplete((result, exception) -> 
snapshotFuture.cancel(true));

Review Comment:
   Similar comment as above.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/statistics/CompleteKeyStatisticsInformation.java:
##########
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.statistics;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.SortedMap;
+
+/**
+ * Class maintained by the controller to merge {@link 
PartialKeyStatisticsInformation} sent by the worker.
+ */
+public class CompleteKeyStatisticsInformation
+{
+  private final SortedMap<Long, Set<Integer>> timeSegmentVsWorkerMap;
+
+  private boolean hasMultipleValues;
+
+  private double bytesRetained;
+
+  public CompleteKeyStatisticsInformation(
+      final SortedMap<Long, Set<Integer>> timeChunks,
+      boolean hasMultipleValues,
+      double bytesRetained
+  )
+  {
+    this.timeSegmentVsWorkerMap = timeChunks;
+    this.hasMultipleValues = hasMultipleValues;
+    this.bytesRetained = bytesRetained;
+  }
+
+  public void mergePartialInformation(int workerNumber, 
PartialKeyStatisticsInformation partialKeyStatisticsInformation)
+  {
+    for (Long timeSegment : partialKeyStatisticsInformation.getTimeSegments()) 
{
+      this.timeSegmentVsWorkerMap
+          .computeIfAbsent(timeSegment, key -> new HashSet<>())
+          .add(workerNumber);
+    }
+    this.hasMultipleValues = this.hasMultipleValues || 
partialKeyStatisticsInformation.isHasMultipleValues();
+    this.bytesRetained += bytesRetained;
+  }
+
+  public SortedMap<Long, Set<Integer>> getTimeSegmentVsWorkerMap()
+  {
+    return timeSegmentVsWorkerMap;

Review Comment:
   Should this be an immutable copy?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to