gianm commented on code in PR #12918:
URL: https://github.com/apache/druid/pull/12918#discussion_r951479650


##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java:
##########
@@ -0,0 +1,2040 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.exec;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
+import it.unimi.dsi.fastutil.ints.IntSet;
+import org.apache.druid.common.guava.FutureUtils;
+import org.apache.druid.data.input.StringTuple;
+import org.apache.druid.data.input.impl.DimensionSchema;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.data.input.impl.TimestampSpec;
+import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
+import org.apache.druid.frame.channel.FrameChannelSequence;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.ClusterByPartition;
+import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.RowKey;
+import org.apache.druid.frame.key.RowKeyReader;
+import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.processor.FrameProcessorExecutor;
+import org.apache.druid.frame.processor.FrameProcessors;
+import org.apache.druid.indexer.TaskState;
+import org.apache.druid.indexer.TaskStatus;
+import org.apache.druid.indexing.common.LockGranularity;
+import org.apache.druid.indexing.common.TaskLock;
+import org.apache.druid.indexing.common.TaskLockType;
+import org.apache.druid.indexing.common.TaskReport;
+import org.apache.druid.indexing.common.actions.LockListAction;
+import org.apache.druid.indexing.common.actions.MarkSegmentsAsUnusedAction;
+import org.apache.druid.indexing.common.actions.RetrieveUsedSegmentsAction;
+import org.apache.druid.indexing.common.actions.SegmentAllocateAction;
+import org.apache.druid.indexing.common.actions.SegmentInsertAction;
+import 
org.apache.druid.indexing.common.actions.SegmentTransactionalInsertAction;
+import org.apache.druid.indexing.overlord.Segments;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.JodaUtils;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.granularity.Granularity;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.java.util.common.guava.Yielder;
+import org.apache.druid.java.util.common.guava.Yielders;
+import org.apache.druid.java.util.common.io.Closer;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.msq.counters.CounterSnapshots;
+import org.apache.druid.msq.counters.CounterSnapshotsTree;
+import org.apache.druid.msq.indexing.ColumnMapping;
+import org.apache.druid.msq.indexing.ColumnMappings;
+import org.apache.druid.msq.indexing.DataSourceMSQDestination;
+import org.apache.druid.msq.indexing.InputChannelFactory;
+import org.apache.druid.msq.indexing.InputChannelsImpl;
+import org.apache.druid.msq.indexing.MSQControllerTask;
+import org.apache.druid.msq.indexing.MSQSpec;
+import org.apache.druid.msq.indexing.MSQTuningConfig;
+import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher;
+import org.apache.druid.msq.indexing.SegmentGeneratorFrameProcessorFactory;
+import org.apache.druid.msq.indexing.TaskReportMSQDestination;
+import org.apache.druid.msq.indexing.error.CanceledFault;
+import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault;
+import org.apache.druid.msq.indexing.error.FaultsExceededChecker;
+import org.apache.druid.msq.indexing.error.InsertCannotAllocateSegmentFault;
+import org.apache.druid.msq.indexing.error.InsertCannotBeEmptyFault;
+import org.apache.druid.msq.indexing.error.InsertCannotOrderByDescendingFault;
+import 
org.apache.druid.msq.indexing.error.InsertCannotReplaceExistingSegmentFault;
+import org.apache.druid.msq.indexing.error.InsertLockPreemptedFault;
+import org.apache.druid.msq.indexing.error.InsertTimeOutOfBoundsFault;
+import org.apache.druid.msq.indexing.error.MSQErrorReport;
+import org.apache.druid.msq.indexing.error.MSQException;
+import org.apache.druid.msq.indexing.error.MSQFault;
+import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher;
+import org.apache.druid.msq.indexing.error.MSQWarnings;
+import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
+import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
+import org.apache.druid.msq.indexing.error.UnknownFault;
+import org.apache.druid.msq.indexing.report.MSQResultsReport;
+import org.apache.druid.msq.indexing.report.MSQStagesReport;
+import org.apache.druid.msq.indexing.report.MSQStatusReport;
+import org.apache.druid.msq.indexing.report.MSQTaskReport;
+import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
+import org.apache.druid.msq.input.InputSpec;
+import org.apache.druid.msq.input.InputSpecSlicer;
+import org.apache.druid.msq.input.InputSpecSlicerFactory;
+import org.apache.druid.msq.input.InputSpecs;
+import org.apache.druid.msq.input.MapInputSpecSlicer;
+import org.apache.druid.msq.input.external.ExternalInputSpec;
+import org.apache.druid.msq.input.external.ExternalInputSpecSlicer;
+import org.apache.druid.msq.input.stage.InputChannels;
+import org.apache.druid.msq.input.stage.ReadablePartition;
+import org.apache.druid.msq.input.stage.StageInputSlice;
+import org.apache.druid.msq.input.stage.StageInputSpec;
+import org.apache.druid.msq.input.stage.StageInputSpecSlicer;
+import org.apache.druid.msq.input.table.TableInputSpec;
+import org.apache.druid.msq.input.table.TableInputSpecSlicer;
+import org.apache.druid.msq.kernel.QueryDefinition;
+import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.kernel.StageId;
+import org.apache.druid.msq.kernel.StagePartition;
+import org.apache.druid.msq.kernel.TargetSizeShuffleSpec;
+import org.apache.druid.msq.kernel.WorkOrder;
+import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
+import org.apache.druid.msq.kernel.controller.ControllerStagePhase;
+import org.apache.druid.msq.kernel.controller.WorkerInputs;
+import org.apache.druid.msq.querykit.DataSegmentTimelineView;
+import org.apache.druid.msq.querykit.MultiQueryKit;
+import org.apache.druid.msq.querykit.QueryKit;
+import org.apache.druid.msq.querykit.QueryKitUtils;
+import org.apache.druid.msq.querykit.ShuffleSpecFactories;
+import org.apache.druid.msq.querykit.ShuffleSpecFactory;
+import org.apache.druid.msq.querykit.groupby.GroupByQueryKit;
+import org.apache.druid.msq.querykit.scan.ScanQueryKit;
+import org.apache.druid.msq.shuffle.DurableStorageInputChannelFactory;
+import org.apache.druid.msq.shuffle.WorkerInputChannelFactory;
+import org.apache.druid.msq.sql.MSQTaskQueryMaker;
+import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import org.apache.druid.msq.util.DimensionSchemaUtils;
+import org.apache.druid.msq.util.IntervalUtils;
+import org.apache.druid.msq.util.MSQFutureUtils;
+import org.apache.druid.msq.util.MultiStageQueryContext;
+import org.apache.druid.msq.util.PassthroughAggregatorFactory;
+import org.apache.druid.query.Query;
+import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.groupby.GroupByQuery;
+import org.apache.druid.query.groupby.GroupByQueryConfig;
+import org.apache.druid.query.scan.ScanQuery;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.Cursor;
+import org.apache.druid.segment.DimensionHandlerUtils;
+import org.apache.druid.segment.column.ColumnHolder;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.indexing.DataSchema;
+import org.apache.druid.segment.indexing.granularity.ArbitraryGranularitySpec;
+import org.apache.druid.segment.indexing.granularity.GranularitySpec;
+import org.apache.druid.segment.realtime.appenderator.SegmentIdWithShardSpec;
+import org.apache.druid.segment.transform.TransformSpec;
+import org.apache.druid.server.DruidNode;
+import org.apache.druid.sql.calcite.rel.DruidQuery;
+import org.apache.druid.timeline.DataSegment;
+import org.apache.druid.timeline.VersionedIntervalTimeline;
+import org.apache.druid.timeline.partition.DimensionRangeShardSpec;
+import org.apache.druid.timeline.partition.NumberedPartialShardSpec;
+import org.apache.druid.timeline.partition.NumberedShardSpec;
+import org.apache.druid.timeline.partition.ShardSpec;
+import org.joda.time.DateTime;
+import org.joda.time.Interval;
+
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.StreamSupport;
+
/**
 * Implementation of {@link Controller}. Executes an {@link MSQControllerTask}: builds a validated
 * {@link QueryDefinition} from the task's query spec, launches worker tasks, drives the
 * {@link ControllerQueryKernel} until all stages finish, publishes segments for ingestion queries, and writes
 * the final task report.
 */
public class ControllerImpl implements Controller
{
  private static final Logger log = new Logger(ControllerImpl.class);

  // The controller task being executed. Also supplies the query spec, destination, and data source.
  private final MSQControllerTask task;

  // Services provided by the hosting process: self node info, worker task clients, report writing, etc.
  private final ControllerContext context;

  // Commands that mutate the query kernel. Enqueued via addToKernelManipulationQueue(); drained by the main
  // controller loop in runQueryUntilDone(). Bounded so a delayed or stuck main loop surfaces as a fast failure
  // in offer() rather than unbounded memory growth.
  private final BlockingQueue<Consumer<ControllerQueryKernel>> kernelManipulationQueue =
      new ArrayBlockingQueue<>(Limits.MAX_KERNEL_MANIPULATION_QUEUE_SIZE);

  // For system error reporting. This is the very first error we got from a worker. (We only report that one.)
  private final AtomicReference<MSQErrorReport> workerErrorRef = new AtomicReference<>();

  // For system warning reporting.
  private final ConcurrentLinkedQueue<MSQErrorReport> workerWarnings = new ConcurrentLinkedQueue<>();

  // For live reports. Set once in initializeQueryDefAndState(); also read by stopGracefully() for logging.
  private final AtomicReference<QueryDefinition> queryDefRef = new AtomicReference<>();

  // For live reports. Last reported CounterSnapshots per stage per worker.
  private final CounterSnapshotsTree taskCountersForLiveReports = new CounterSnapshotsTree();

  // For live reports. stage number -> stage phase.
  private final ConcurrentHashMap<Integer, ControllerStagePhase> stagePhasesForLiveReports = new ConcurrentHashMap<>();

  // For live reports. stage number -> runtime interval. Endpoint is eternity's end if the stage is still running.
  private final ConcurrentHashMap<Integer, Interval> stageRuntimesForLiveReports = new ConcurrentHashMap<>();

  // For live reports. stage number -> worker count. Only set for stages that have started.
  private final ConcurrentHashMap<Integer, Integer> stageWorkerCountsForLiveReports = new ConcurrentHashMap<>();

  // For live reports. stage number -> partition count. Only set for stages that have started.
  private final ConcurrentHashMap<Integer, Integer> stagePartitionCountsForLiveReports = new ConcurrentHashMap<>();

  // For live reports. The time at which the query started. Written once in runTask(); volatile for visibility
  // to threads reading live-report state.
  private volatile DateTime queryStartTime = null;

  // The following fields are assigned in initializeQueryDefAndState(); volatile so threads other than the main
  // controller loop observe the initialized values.
  private volatile DruidNode selfDruidNode;
  private volatile MSQWorkerTaskLauncher workerTaskLauncher;
  private volatile WorkerClient netClient;

  private volatile FaultsExceededChecker faultsExceededChecker = null;
+
  /**
   * Creates a controller for the given task. No work starts until {@link #run()} is called.
   *
   * @param task    the MSQ controller task this instance executes
   * @param context services provided by the hosting process (node info, worker clients, report writing)
   */
  public ControllerImpl(
      final MSQControllerTask task,
      final ControllerContext context
  )
  {
    this.task = task;
    this.context = context;
  }
+
  /**
   * Returns the ID of the underlying controller task, used throughout as the controller/query identifier
   * (reports, task statuses, log messages).
   */
  @Override
  public String id()
  {
    return task.getId();
  }
+
  /**
   * Returns the {@link MSQControllerTask} this controller is executing.
   */
  @Override
  public MSQControllerTask task()
  {
    return task;
  }
+
  /**
   * Entry point: runs the task and translates any unhandled throwable into a failed {@link TaskStatus}.
   * {@link #runTask} is expected to handle errors itself; anything escaping it is unexpected, so it is logged
   * and converted into a failure status rather than propagated.
   */
  @Override
  public TaskStatus run() throws Exception
  {
    final Closer closer = Closer.create();

    try {
      return runTask(closer);
    }
    catch (Throwable e) {
      // Close eagerly here so a close-time exception can be attached to the original error as suppressed,
      // instead of masking it. NOTE(review): this relies on Closer.close() being safe to call a second time
      // in the finally block below — confirm against Guava's Closer contract.
      try {
        closer.close();
      }
      catch (Throwable e2) {
        e.addSuppressed(e2);
      }

      // We really don't expect this to error out. runTask should handle everything nicely. If it doesn't, something
      // strange happened, so log it.
      log.warn(e, "Encountered unhandled controller exception.");
      return TaskStatus.failure(id(), e.toString());
    }
    finally {
      closer.close();
    }
  }
+
+  @Override
+  public void stopGracefully()
+  {
+    final QueryDefinition queryDef = queryDefRef.get();
+
+    // stopGracefully() is called when the containing process is terminated, 
or when the task is canceled.
+    log.info("Query [%s] canceled.", queryDef != null ? queryDef.getQueryId() 
: "<no id yet>");
+
+    addToKernelManipulationQueue(
+        kernel -> {
+          throw new MSQException(CanceledFault.INSTANCE);
+        }
+    );
+  }
+
  /**
   * Runs the query to completion and produces the final {@link TaskStatus}.
   *
   * This method does not throw for "normal" query failures: those are captured, logged, folded into the task
   * report, and returned as a failed status. Lifecycle: initialize state, run the stage kernel to completion,
   * fetch final results, publish segments (ingestion only), snapshot counters, write the task report, then
   * stop/await worker tasks and clean up durable storage.
   *
   * @param closer resources registered here are released by {@link #run()} when this method returns or throws
   */
  public TaskStatus runTask(final Closer closer)
  {
    QueryDefinition queryDef = null;
    ControllerQueryKernel queryKernel = null;
    ListenableFuture<?> workerTaskRunnerFuture = null;
    CounterSnapshotsTree countersSnapshot = null;
    Yielder<Object[]> resultsYielder = null;
    Throwable exceptionEncountered = null;

    final TaskState taskStateForReport;
    final MSQErrorReport errorForReport;

    try {
      this.queryStartTime = DateTimes.nowUtc();
      queryDef = initializeQueryDefAndState(closer);

      final InputSpecSlicerFactory inputSpecSlicerFactory = makeInputSpecSlicerFactory(makeDataSegmentTimelineView());
      final Pair<ControllerQueryKernel, ListenableFuture<?>> queryRunResult =
          runQueryUntilDone(queryDef, inputSpecSlicerFactory, closer);

      queryKernel = Preconditions.checkNotNull(queryRunResult.lhs);
      workerTaskRunnerFuture = Preconditions.checkNotNull(queryRunResult.rhs);
      resultsYielder = getFinalResultsYielder(queryDef, queryKernel);
      publishSegmentsIfNeeded(queryDef, queryKernel);
    }
    catch (Throwable e) {
      // Don't rethrow: the failure is folded into the error report and the returned TaskStatus below.
      exceptionEncountered = e;
    }

    // Fetch final counters in separate try, in case runQueryUntilDone threw an exception.
    try {
      countersSnapshot = getFinalCountersSnapshot(queryKernel);
    }
    catch (Throwable e) {
      if (exceptionEncountered != null) {
        exceptionEncountered.addSuppressed(e);
      } else {
        exceptionEncountered = e;
      }
    }

    if (queryKernel != null && queryKernel.isSuccess() && exceptionEncountered == null) {
      taskStateForReport = TaskState.SUCCESS;
      errorForReport = null;
    } else {
      // Query failure. Generate an error report and log the error(s) we encountered.
      final String selfHost = MSQTasks.getHostFromSelfNode(selfDruidNode);
      final MSQErrorReport controllerError =
          exceptionEncountered != null
          ? MSQErrorReport.fromException(id(), selfHost, null, exceptionEncountered)
          : null;
      final MSQErrorReport workerError = workerErrorRef.get();

      taskStateForReport = TaskState.FAILED;
      errorForReport = MSQTasks.makeErrorReport(id(), selfHost, controllerError, workerError);

      // Log the errors we encountered.
      if (controllerError != null) {
        log.warn("Controller: %s", MSQTasks.errorReportToLogMessage(controllerError));
      }

      if (workerError != null) {
        log.warn("Worker: %s", MSQTasks.errorReportToLogMessage(workerError));
      }
    }

    try {
      // Write report even if something went wrong.
      final MSQStagesReport stagesReport;
      final MSQResultsReport resultsReport;

      if (queryDef != null) {
        final Map<Integer, ControllerStagePhase> stagePhaseMap;

        if (queryKernel != null) {
          // Once the query finishes, cleanup would have happened for all the stages that were successful.
          // Therefore we mark it as done to make the reports prettier and more accurate.
          queryKernel.markSuccessfulTerminalStagesAsFinished();
          stagePhaseMap = queryKernel.getActiveStages()
                                     .stream()
                                     .collect(
                                         Collectors.toMap(StageId::getStageNumber, queryKernel::getStagePhase)
                                     );
        } else {
          stagePhaseMap = Collections.emptyMap();
        }

        stagesReport = makeStageReport(
            queryDef,
            stagePhaseMap,
            stageRuntimesForLiveReports,
            stageWorkerCountsForLiveReports,
            stagePartitionCountsForLiveReports
        );
      } else {
        stagesReport = null;
      }

      if (resultsYielder != null) {
        resultsReport = makeResultsTaskReport(
            queryDef,
            resultsYielder,
            task.getQuerySpec().getColumnMappings(),
            task.getSqlTypeNames()
        );
      } else {
        resultsReport = null;
      }

      final MSQTaskReportPayload taskReportPayload = new MSQTaskReportPayload(
          makeStatusReport(
              taskStateForReport,
              errorForReport,
              workerWarnings,
              queryStartTime,
              new Interval(queryStartTime, DateTimes.nowUtc()).toDurationMillis()
          ),
          stagesReport,
          countersSnapshot,
          resultsReport
      );

      context.writeReports(
          id(),
          TaskReport.buildTaskReports(new MSQTaskReport(id(), taskReportPayload))
      );
    }
    catch (Throwable e) {
      // Report writing is best-effort; the task status below is still returned correctly.
      log.warn(e, "Error encountered while writing task report. Skipping.");
    }

    if (queryKernel != null && queryKernel.isSuccess()) {
      // If successful, encourage the tasks to exit successfully.
      postFinishToAllTasks();
      workerTaskLauncher.stop(false);
    } else {
      // If not successful, cancel running tasks.
      // (workerTaskLauncher may be null if initializeQueryDefAndState failed before creating it.)
      if (workerTaskLauncher != null) {
        workerTaskLauncher.stop(true);
      }
    }

    // Wait for worker tasks to exit. Ignore their return status. At this point, we've done everything we need to do,
    // so we don't care about the task exit status.
    if (workerTaskRunnerFuture != null) {
      try {
        workerTaskRunnerFuture.get();
      }
      catch (Exception ignored) {
        // Suppress: see the comment above — exit status of workers is irrelevant at this point.
      }
    }

    cleanUpDurableStorageIfNeeded();

    if (taskStateForReport == TaskState.SUCCESS) {
      return TaskStatus.success(id());
    } else {
      // errorForReport is nonnull when taskStateForReport != SUCCESS. Use that message.
      return TaskStatus.failure(id(), errorForReport.getFault().getCodeWithMessage());
    }
  }
+
+  /**
+   * Adds some logic to {@link #kernelManipulationQueue}, where it will, in 
due time, be executed by the main
+   * controller loop in {@link #runQueryUntilDone}.
+   *
+   * If the consumer throws an exception, the query fails.
+   */
+  private void addToKernelManipulationQueue(Consumer<ControllerQueryKernel> 
kernelConsumer)
+  {
+    if (!kernelManipulationQueue.offer(kernelConsumer)) {
+      final String message = "Controller kernel queue is full. Main controller 
loop may be delayed or stuck.";
+      log.warn(message);
+      throw new IllegalStateException(message);
+    }
+  }
+
  /**
   * One-time initialization performed at the start of {@link #runTask}: registers this controller with the
   * context, creates the worker client and worker-task launcher, builds and validates the
   * {@link QueryDefinition}, and configures the parse-exception fault checker.
   *
   * @param closer receives the worker client (and anything the context registers) for cleanup
   * @return the validated query definition; also published to {@link #queryDefRef} for live reports
   */
  private QueryDefinition initializeQueryDefAndState(final Closer closer)
  {
    this.selfDruidNode = context.selfNode();
    context.registerController(this, closer);

    this.netClient = new ExceptionWrappingWorkerClient(context.taskClientFor(this));
    closer.register(netClient::close);

    final boolean isDurableStorageEnabled =
        MultiStageQueryContext.isDurableStorageEnabled(task.getQuerySpec().getQuery().getContext());

    final QueryDefinition queryDef = makeQueryDefinition(
        id(),
        makeQueryControllerToolKit(),
        task.getQuerySpec()
    );

    QueryValidator.validateQueryDef(queryDef);
    queryDefRef.set(queryDef);

    log.debug("Query [%s] durable storage mode is set to %s.", queryDef.getQueryId(), isDurableStorageEnabled);

    this.workerTaskLauncher = new MSQWorkerTaskLauncher(
        id(),
        task.getDataSource(),
        context,
        isDurableStorageEnabled,

        // 10 minutes +- 2 minutes jitter (nextInt(-4, 5) * 30s yields -120s..+120s in 30s steps)
        TimeUnit.SECONDS.toMillis(600 + ThreadLocalRandom.current().nextInt(-4, 5) * 30L)
    );

    // Used when no SQL query context is present. NOTE(review): presumably -1 means "unlimited" — confirm
    // against FaultsExceededChecker.
    long maxParseExceptions = -1;

    if (task.getSqlQueryContext() != null) {
      maxParseExceptions = Optional.ofNullable(
                                       task.getSqlQueryContext().get(MSQWarnings.CTX_MAX_PARSE_EXCEPTIONS_ALLOWED))
                                   .map(DimensionHandlerUtils::convertObjectToLong)
                                   .orElse(MSQWarnings.DEFAULT_MAX_PARSE_EXCEPTIONS_ALLOWED);
    }

    this.faultsExceededChecker = new FaultsExceededChecker(
        ImmutableMap.of(CannotParseExternalDataFault.CODE, maxParseExceptions)
    );

    return queryDef;
  }
+
  /**
   * Main controller loop. Drives the {@link ControllerQueryKernel} until the query is done: launches worker
   * tasks, starts new stages as the kernel creates them (allocating segments for the final stage of an
   * ingestion), distributes partition boundaries, maintains live-report state, issues cleanup orders for
   * effectively-finished stages, and drains {@link #kernelManipulationQueue} between iterations.
   *
   * @return the finished (successful or failed) kernel paired with the worker-task-launcher future, so the
   *         caller can inspect final state and await worker exit
   *
   * @throws MSQException if a stage failed with a known fault; {@link UnknownFault} is deliberately not thrown
   *                      here so a better exception can be generated during query teardown
   */
  private Pair<ControllerQueryKernel, ListenableFuture<?>> runQueryUntilDone(
      final QueryDefinition queryDef,
      final InputSpecSlicerFactory inputSpecSlicerFactory,
      final Closer closer
  ) throws Exception
  {
    // Start tasks.
    log.debug("Query [%s] starting tasks.", queryDef.getQueryId());

    final ListenableFuture<?> workerTaskLauncherFuture = workerTaskLauncher.start();
    closer.register(() -> workerTaskLauncher.stop(true));

    // If the launcher future completes (i.e., the launcher failed or stopped), surface that in the main loop.
    workerTaskLauncherFuture.addListener(
        () ->
            addToKernelManipulationQueue(queryKernel -> {
              // Throw an exception in the main loop, if anything went wrong.
              FutureUtils.getUncheckedImmediately(workerTaskLauncherFuture);
            }),
        Execs.directExecutor()
    );

    // Segments to generate; used for making stage-two workers.
    List<SegmentIdWithShardSpec> segmentsToGenerate = null;

    // Track which stages have got their partition boundaries sent out yet.
    final Set<StageId> stageResultPartitionBoundariesSent = new HashSet<>();

    // Start query tracking loop.
    log.debug("Query [%s] starting tracker.", queryDef.getQueryId());
    final ControllerQueryKernel queryKernel = new ControllerQueryKernel(queryDef);

    while (!queryKernel.isDone()) {
      // Start stages that need to be started.
      logKernelStatus(queryDef.getQueryId(), queryKernel);
      final List<StageId> newStageIds = queryKernel.createAndGetNewStageIds(
          inputSpecSlicerFactory,
          task.getQuerySpec().getAssignmentStrategy()
      );

      for (final StageId stageId : newStageIds) {
        queryKernel.startStage(stageId);

        // Allocate segments, if this is the final stage of an ingestion.
        if (MSQControllerTask.isIngestion(task.getQuerySpec())
            && stageId.getStageNumber() == queryDef.getFinalStageDefinition().getStageNumber()) {
          // We need to find the shuffle details (like partition ranges) to generate segments. Generally this is
          // going to correspond to the stage immediately prior to the final segment-generator stage.
          int shuffleStageNumber = Iterables.getOnlyElement(queryDef.getFinalStageDefinition().getInputStageNumbers());

          // The following logic assumes that output of all the stages without a shuffle retain the partition
          // boundaries of the input to that stage. This may not always be the case. For example GROUP BY queries
          // without an ORDER BY clause. This works for QueryKit generated queries until now, but it should be
          // reworked as it might not always be the case.
          while (!queryDef.getStageDefinition(shuffleStageNumber).doesShuffle()) {
            shuffleStageNumber =
                Iterables.getOnlyElement(queryDef.getStageDefinition(shuffleStageNumber).getInputStageNumbers());
          }

          final StageId shuffleStageId = new StageId(queryDef.getQueryId(), shuffleStageNumber);
          final boolean isTimeBucketed = isTimeBucketedIngestion(task.getQuerySpec());
          final ClusterByPartitions partitionBoundaries =
              queryKernel.getResultPartitionBoundariesForStage(shuffleStageId);

          // We require some data to be inserted in case it is partitioned by anything other than all and we are
          // inserting everything into a single bucket. This can be handled more gracefully instead of throwing an
          // exception. Note: This can also be the case when we have limit queries but validation in Broker SQL
          // layer prevents such queries.
          if (isTimeBucketed && partitionBoundaries.equals(ClusterByPartitions.oneUniversalPartition())) {
            throw new MSQException(new InsertCannotBeEmptyFault(task.getDataSource()));
          } else {
            log.info("Query [%s] generating %d segments.", queryDef.getQueryId(), partitionBoundaries.size());
          }

          // If key statistics were never gathered, we cannot rule out multi-valued cluster-by fields.
          final boolean mayHaveMultiValuedClusterByFields =
              !queryKernel.getStageDefinition(shuffleStageId).mustGatherResultKeyStatistics()
              || queryKernel.hasStageCollectorEncounteredAnyMultiValueField(shuffleStageId);

          segmentsToGenerate = generateSegmentIdsWithShardSpecs(
              (DataSourceMSQDestination) task.getQuerySpec().getDestination(),
              queryKernel.getStageDefinition(shuffleStageId).getSignature(),
              queryKernel.getStageDefinition(shuffleStageId).getShuffleSpec().get().getClusterBy(),
              partitionBoundaries,
              mayHaveMultiValuedClusterByFields
          );
        }

        final int workerCount = queryKernel.getWorkerInputsForStage(stageId).workerCount();
        log.info(
            "Query [%s] starting %d workers for stage %d.",
            stageId.getQueryId(),
            workerCount,
            stageId.getStageNumber()
        );

        workerTaskLauncher.launchTasksIfNeeded(workerCount);
        stageRuntimesForLiveReports.put(stageId.getStageNumber(), new Interval(DateTimes.nowUtc(), DateTimes.MAX));
        startWorkForStage(queryDef, queryKernel, stageId.getStageNumber(), segmentsToGenerate);
      }

      // Send partition boundaries to tasks, if the time is right.
      // stageResultPartitionBoundariesSent.add() ensures we send boundaries for each stage at most once.
      logKernelStatus(queryDef.getQueryId(), queryKernel);
      for (final StageId stageId : queryKernel.getActiveStages()) {

        if (queryKernel.getStageDefinition(stageId).mustGatherResultKeyStatistics()
            && queryKernel.doesStageHaveResultPartitions(stageId)
            && stageResultPartitionBoundariesSent.add(stageId)) {
          if (log.isDebugEnabled()) {
            final ClusterByPartitions partitions = queryKernel.getResultPartitionBoundariesForStage(stageId);
            log.debug(
                "Query [%s] sending out partition boundaries for stage %d: %s",
                stageId.getQueryId(),
                stageId.getStageNumber(),
                IntStream.range(0, partitions.size())
                         .mapToObj(i -> StringUtils.format("%s:%s", i, partitions.get(i)))
                         .collect(Collectors.joining(", "))
            );
          } else {
            log.info(
                "Query [%s] sending out partition boundaries for stage %d.",
                stageId.getQueryId(),
                stageId.getStageNumber()
            );
          }

          postResultPartitionBoundariesForStage(
              queryDef,
              stageId.getStageNumber(),
              queryKernel.getResultPartitionBoundariesForStage(stageId),
              queryKernel.getWorkerInputsForStage(stageId).workers()
          );
        }
      }

      logKernelStatus(queryDef.getQueryId(), queryKernel);

      // Live reports: update stage phases, worker counts, partition counts.
      for (StageId stageId : queryKernel.getActiveStages()) {
        final int stageNumber = stageId.getStageNumber();
        stagePhasesForLiveReports.put(stageNumber, queryKernel.getStagePhase(stageId));

        if (queryKernel.doesStageHaveResultPartitions(stageId)) {
          // computeIfAbsent: the partition count is computed once, the first time partitions are available.
          stagePartitionCountsForLiveReports.computeIfAbsent(
              stageNumber,
              k -> Iterators.size(queryKernel.getResultPartitionsForStage(stageId).iterator())
          );
        }

        stageWorkerCountsForLiveReports.putIfAbsent(
            stageNumber,
            queryKernel.getWorkerInputsForStage(stageId).workerCount()
        );
      }

      // Live reports: update stage end times for any stages that just ended.
      // (An endpoint of DateTimes.MAX means "still running"; replace it with "now" exactly once.)
      for (StageId stageId : queryKernel.getActiveStages()) {
        if (ControllerStagePhase.isSuccessfulTerminalPhase(queryKernel.getStagePhase(stageId))) {
          stageRuntimesForLiveReports.compute(
              queryKernel.getStageDefinition(stageId).getStageNumber(),
              (k, currentValue) -> {
                if (currentValue.getEnd().equals(DateTimes.MAX)) {
                  return new Interval(currentValue.getStart(), DateTimes.nowUtc());
                } else {
                  return currentValue;
                }
              }
          );
        }
      }

      // Notify the workers to clean up the stages which can be marked as finished.
      cleanUpEffectivelyFinishedStages(queryDef, queryKernel);

      if (!queryKernel.isDone()) {
        // Run the next command, waiting for it if necessary.
        Consumer<ControllerQueryKernel> command = kernelManipulationQueue.take();
        command.accept(queryKernel);

        // Run all pending commands after that one. Helps avoid deep queues.
        // After draining the command queue, move on to the next iteration of the controller loop.
        while ((command = kernelManipulationQueue.poll()) != null) {
          command.accept(queryKernel);
        }
      }
    }

    if (!queryKernel.isSuccess()) {
      // Look for a known failure reason and throw a meaningful exception.
      for (final StageId stageId : queryKernel.getActiveStages()) {
        if (queryKernel.getStagePhase(stageId) == ControllerStagePhase.FAILED) {
          final MSQFault fault = queryKernel.getFailureReasonForStage(stageId);

          // Fall through (without throwing an exception) in case of UnknownFault; we may be able to generate
          // a better exception later in query teardown.
          if (!UnknownFault.CODE.equals(fault.getErrorCode())) {
            throw new MSQException(fault);
          }
        }
      }
    }

    // Final cleanup pass for anything that became effectively finished on the last iteration.
    cleanUpEffectivelyFinishedStages(queryDef, queryKernel);
    return Pair.of(queryKernel, workerTaskLauncherFuture);
  }
+
+  private void cleanUpEffectivelyFinishedStages(QueryDefinition queryDef, 
ControllerQueryKernel queryKernel)
+  {
+    for (final StageId stageId : queryKernel.getEffectivelyFinishedStageIds()) 
{
+      log.info("Query [%s] issuing cleanup order for stage %d.", 
queryDef.getQueryId(), stageId.getStageNumber());
+      contactWorkersForStage(
+          (netClient, taskId, workerNumber) -> 
netClient.postCleanupStage(taskId, stageId),
+          queryKernel.getWorkerInputsForStage(stageId).workers()
+      );
+      queryKernel.finishStage(stageId, true);
+    }
+  }
+
+  /**
+   * Provide a {@link ClusterByStatisticsSnapshot} for shuffling stages.
+   */
+  @Override
+  public void updateStatus(int stageNumber, int workerNumber, Object 
keyStatisticsObject)
+  {
+    addToKernelManipulationQueue(
+        queryKernel -> {
+          final StageId stageId = queryKernel.getStageId(stageNumber);
+
+          // We need a specially-decorated ObjectMapper to deserialize key 
statistics.
+          final StageDefinition stageDef = 
queryKernel.getStageDefinition(stageId);
+          final ObjectMapper mapper = 
MSQTasks.decorateObjectMapperForKeyCollectorSnapshot(
+              context.jsonMapper(),
+              stageDef.getShuffleSpec().get().getClusterBy(),
+              stageDef.getShuffleSpec().get().doesAggregateByClusterKey()
+          );
+
+          final ClusterByStatisticsSnapshot keyStatistics;
+          try {
+            keyStatistics = mapper.convertValue(keyStatisticsObject, 
ClusterByStatisticsSnapshot.class);
+          }
+          catch (IllegalArgumentException e) {
+            throw new IAE(
+                e,
+                "Unable to deserialize the key statistic for stage [%s] 
received from the worker [%d]",
+                stageId,
+                workerNumber
+            );
+          }
+
+          queryKernel.addResultKeyStatisticsForStageAndWorker(stageId, 
workerNumber, keyStatistics);
+        }
+    );
+  }
+
+  @Override
+  public void workerError(MSQErrorReport errorReport)
+  {
+    if 
(!workerTaskLauncher.isTaskCanceledByController(errorReport.getTaskId())) {
+      workerErrorRef.compareAndSet(null, errorReport);
+    }
+  }
+
+  /**
+   * This method intakes all the warnings that are generated by the worker. It 
is the responsibility of the
+   * worker node to ensure that it doesn't spam the controller with 
unneseccary warning stack traces. Currently, that
+   * limiting is implemented in {@link MSQWarningReportLimiterPublisher}
+   */
+  @Override
+  public void workerWarning(List<MSQErrorReport> errorReports)
+  {
+    // This check safeguards that the controller doesn't run out of memory. 
Workers apply their own limiting to
+    // protect their own memory, and to conserve worker -> controller 
bandwidth.
+    long numReportsToAddCheck = Math.min(
+        errorReports.size(),
+        Limits.MAX_WORKERS * Limits.MAX_VERBOSE_WARNINGS - 
workerWarnings.size()
+    );
+    if (numReportsToAddCheck > 0) {
+      synchronized (workerWarnings) {
+        long numReportsToAdd = Math.min(
+            errorReports.size(),
+            Limits.MAX_WORKERS * Limits.MAX_VERBOSE_WARNINGS - 
workerWarnings.size()
+        );
+        for (int i = 0; i < numReportsToAdd; ++i) {
+          workerWarnings.add(errorReports.get(i));
+        }
+      }
+    }
+  }
+
+  /**
+   * Periodic update of {@link CounterSnapshots} from subtasks.
+   */
+  @Override
+  public void updateCounters(CounterSnapshotsTree snapshotsTree)
+  {
+    taskCountersForLiveReports.putAll(snapshotsTree);
+    Optional<Pair<String, Long>> warningsExceeded =
+        
faultsExceededChecker.addFaultsAndCheckIfExceeded(taskCountersForLiveReports);
+
+    if (warningsExceeded.isPresent()) {

Review Comment:
   The javadoc for `addFaultsAndCheckIfExceeded` describes what it means. I'll
add a comment here, too.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to