gianm commented on code in PR #12918: URL: https://github.com/apache/druid/pull/12918#discussion_r951574314
########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java: ########## @@ -0,0 +1,1230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.exec; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.SettableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.frame.allocation.ArenaMemoryAllocator; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.frame.file.FrameFileWriter; +import org.apache.druid.frame.key.ClusterBy; +import 
org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.frame.processor.FileOutputChannelFactory; +import org.apache.druid.frame.processor.FrameChannelMuxer; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessorExecutor; +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.frame.processor.OutputChannelFactory; +import org.apache.druid.frame.processor.OutputChannels; +import org.apache.druid.frame.processor.SuperSorter; +import org.apache.druid.frame.processor.SuperSorterProgressTracker; +import org.apache.druid.indexer.TaskStatus; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.counters.CounterNames; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.counters.CounterTracker; +import org.apache.druid.msq.indexing.CountingOutputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelsImpl; +import org.apache.druid.msq.indexing.KeyStatisticsCollectionProcessor; +import org.apache.druid.msq.indexing.MSQWorkerTask; +import org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher; +import org.apache.druid.msq.indexing.error.MSQWarningReportPublisher; 
+import org.apache.druid.msq.indexing.error.MSQWarningReportSimplePublisher; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.InputSliceReader; +import org.apache.druid.msq.input.InputSlices; +import org.apache.druid.msq.input.MapInputSliceReader; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.NilInputSliceReader; +import org.apache.druid.msq.input.external.ExternalInputSlice; +import org.apache.druid.msq.input.external.ExternalInputSliceReader; +import org.apache.druid.msq.input.stage.InputChannels; +import org.apache.druid.msq.input.stage.ReadablePartition; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.input.stage.StageInputSliceReader; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.SegmentsInputSliceReader; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.FrameProcessorFactory; +import org.apache.druid.msq.kernel.ProcessorsAndChannels; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.StagePartition; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.kernel.worker.WorkerStageKernel; +import org.apache.druid.msq.kernel.worker.WorkerStagePhase; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.shuffle.DurableStorageInputChannelFactory; +import org.apache.druid.msq.shuffle.DurableStorageOutputChannelFactory; +import org.apache.druid.msq.shuffle.WorkerInputChannelFactory; +import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.msq.util.DecoratedExecutorService; +import org.apache.druid.msq.util.MultiStageQueryContext; +import 
org.apache.druid.query.PrioritizedCallable; +import org.apache.druid.query.PrioritizedRunnable; +import org.apache.druid.query.QueryProcessingPool; +import org.apache.druid.server.DruidNode; + +import javax.annotation.Nullable; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +/** + * Interface for a worker of a multi-stage query. 
+ */ +public class WorkerImpl implements Worker +{ + private static final Logger log = new Logger(WorkerImpl.class); + + private final MSQWorkerTask task; + private final WorkerContext context; + + private final BlockingQueue<Consumer<KernelHolder>> kernelManipulationQueue = new LinkedBlockingDeque<>(); + private final ConcurrentHashMap<StageId, ConcurrentHashMap<Integer, ReadableFrameChannel>> stageOutputs = new ConcurrentHashMap<>(); + private final ConcurrentHashMap<StageId, CounterTracker> stageCounters = new ConcurrentHashMap<>(); + private final boolean durableStageStorageEnabled; + + private volatile DruidNode selfDruidNode; + private volatile ControllerClient controllerClient; + private volatile WorkerClient workerClient; + private volatile Bouncer processorBouncer; + private volatile boolean controllerAlive = true; + + public WorkerImpl(MSQWorkerTask task, WorkerContext context) + { + this.task = task; + this.context = context; + this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(task.getContext()); + } + + @Override + public String id() + { + return task.getId(); + } + + @Override + public MSQWorkerTask task() + { + return task; + } + + @Override + public TaskStatus run() throws Exception + { + try (final Closer closer = Closer.create()) { + Optional<MSQErrorReport> maybeErrorReport; + + try { + maybeErrorReport = runTask(closer); + } + catch (Throwable e) { + maybeErrorReport = Optional.of( + MSQErrorReport.fromException(id(), MSQTasks.getHostFromSelfNode(selfDruidNode), null, e) + ); + } + + if (maybeErrorReport.isPresent()) { + final MSQErrorReport errorReport = maybeErrorReport.get(); + final String errorLogMessage = MSQTasks.errorReportToLogMessage(errorReport); + log.warn(errorLogMessage); + + closer.register(() -> { + if (controllerAlive && controllerClient != null && selfDruidNode != null) { + controllerClient.postWorkerError(id(), errorReport); + } + }); + + return TaskStatus.failure(id(), 
errorReport.getFault().getCodeWithMessage()); + } else { + return TaskStatus.success(id()); + } + } + } + + /** + * Runs worker logic. Returns an empty Optional on success. On failure, returns an error report for errors that + * happened in other threads; throws exceptions for errors that happened in the main worker loop. + */ + public Optional<MSQErrorReport> runTask(final Closer closer) throws Exception + { + this.selfDruidNode = context.selfNode(); + this.controllerClient = context.makeControllerClient(task.getControllerTaskId()); + closer.register(controllerClient::close); + context.registerWorker(this, closer); // Uses controllerClient, so must be called after that is initialized + this.workerClient = new ExceptionWrappingWorkerClient(context.makeWorkerClient()); + closer.register(workerClient::close); + this.processorBouncer = context.processorBouncer(); + + final KernelHolder kernelHolder = new KernelHolder(); + final String cancellationId = id(); + + final FrameProcessorExecutor workerExec = new FrameProcessorExecutor(makeProcessingPool()); + + // Delete all the stage outputs + closer.register(() -> { + for (final StageId stageId : stageOutputs.keySet()) { + cleanStageOutput(stageId); + } + }); + + // Close stage output processors and running futures (if present) + closer.register(() -> { + try { + workerExec.cancel(cancellationId); + } + catch (InterruptedException e) { + // Strange that cancellation would itself be interrupted. Throw an exception, since this is unexpected. 
+ throw new RuntimeException(e); + } + }); + + final MSQWarningReportPublisher msqWarningReportPublisher = new MSQWarningReportLimiterPublisher( + new MSQWarningReportSimplePublisher( + id(), + controllerClient, + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode) + ) + ); + + closer.register(msqWarningReportPublisher); + + final Map<StageId, SettableFuture<ClusterByPartitions>> partitionBoundariesFutureMap = new HashMap<>(); + + final Map<StageId, FrameContext> stageFrameContexts = new HashMap<>(); + + while (!kernelHolder.isDone()) { + boolean didSomething = false; + + for (final WorkerStageKernel kernel : kernelHolder.getStageKernelMap().values()) { + final StageDefinition stageDefinition = kernel.getStageDefinition(); + + if (kernel.getPhase() == WorkerStagePhase.NEW) { + log.debug("New work order: %s", context.jsonMapper().writeValueAsString(kernel.getWorkOrder())); + + // Create separate inputChannelFactory per stage, because the list of tasks can grow between stages, and + // so we need to avoid the memoization in baseInputChannelFactory. + final InputChannelFactory inputChannelFactory = makeBaseInputChannelFactory(closer); + + // Compute memory parameters for all stages, even ones that haven't been assigned yet, so we can fail-fast + // if some won't work. (We expect that all stages will get assigned to the same pool of workers.) + for (final StageDefinition stageDef : kernel.getWorkOrder().getQueryDefinition().getStageDefinitions()) { + stageFrameContexts.computeIfAbsent( + stageDef.getId(), + stageId -> context.frameContext( + kernel.getWorkOrder().getQueryDefinition(), + stageId.getStageNumber() + ) + ); + } + + // Start working on this stage immediately. 
+ kernel.startReading(); + final SettableFuture<ClusterByPartitions> partitionBoundariesFuture = + startWorkOrder( + kernel, + inputChannelFactory, + stageCounters.computeIfAbsent(stageDefinition.getId(), ignored -> new CounterTracker()), + workerExec, + cancellationId, + context.threadCount(), + stageFrameContexts.get(stageDefinition.getId()), + msqWarningReportPublisher + ); + + if (partitionBoundariesFuture != null) { + if (partitionBoundariesFutureMap.put(stageDefinition.getId(), partitionBoundariesFuture) != null) { + throw new ISE("Work order collision for stage [%s]", stageDefinition.getId()); + } + } + + didSomething = true; + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + + if (kernel.getPhase() == WorkerStagePhase.READING_INPUT && kernel.hasResultKeyStatisticsSnapshot()) { + if (controllerAlive) { + controllerClient.postKeyStatistics( + stageDefinition.getId(), + kernel.getWorkOrder().getWorkerNumber(), + kernel.getResultKeyStatisticsSnapshot() + ); + } + kernel.startPreshuffleWaitingForResultPartitionBoundaries(); + + didSomething = true; + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + + logKernelStatus(kernelHolder.getStageKernelMap().values()); + if (kernel.getPhase() == WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES + && kernel.hasResultPartitionBoundaries()) { + partitionBoundariesFutureMap.get(stageDefinition.getId()).set(kernel.getResultPartitionBoundaries()); + kernel.startPreshuffleWritingOutput(); + + didSomething = true; + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + + if (kernel.getPhase() == WorkerStagePhase.RESULTS_READY + && kernel.addPostedResultsComplete(Pair.of(stageDefinition.getId(), kernel.getWorkOrder().getWorkerNumber()))) { + if (controllerAlive) { + controllerClient.postResultsComplete( + stageDefinition.getId(), + kernel.getWorkOrder().getWorkerNumber(), + kernel.getResultObject() + ); + } + } + + if (kernel.getPhase() == WorkerStagePhase.FAILED) { + 
// Better than throwing an exception, because we can include the stage number. + return Optional.of( + MSQErrorReport.fromException( + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode), + stageDefinition.getId().getStageNumber(), + kernel.getException() + ) + ); + } + } + + if (!didSomething && !kernelHolder.isDone()) { + Consumer<KernelHolder> nextCommand; + + do { + postCountersToController(); + } while ((nextCommand = kernelManipulationQueue.poll(5, TimeUnit.SECONDS)) == null); + + nextCommand.accept(kernelHolder); + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + } + + // Empty means success. + return Optional.empty(); + } + + @Override + public void stopGracefully() + { + kernelManipulationQueue.add( + kernel -> { + // stopGracefully() is called when the containing process is terminated, or when the task is canceled. + throw new MSQException(CanceledFault.INSTANCE); + } + ); + } + + @Override + public void controllerFailed() + { + controllerAlive = false; + stopGracefully(); + } + + @Override + public InputStream readChannel( + final String queryId, + final int stageNumber, + final int partitionNumber, + final long offset + ) throws IOException + { + final StageId stageId = new StageId(queryId, stageNumber); + final StagePartition stagePartition = new StagePartition(stageId, partitionNumber); + final ConcurrentHashMap<Integer, ReadableFrameChannel> partitionOutputsForStage = stageOutputs.get(stageId); + + if (partitionOutputsForStage == null) { + return null; + } + final ReadableFrameChannel channel = partitionOutputsForStage.get(partitionNumber); + + if (channel == null) { + return null; + } + + if (channel instanceof ReadableNilFrameChannel) { + // Build an empty frame file. 
+ final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + FrameFileWriter.open(Channels.newChannel(baos), null).close(); + + final ByteArrayInputStream in = new ByteArrayInputStream(baos.toByteArray()); + + //noinspection ResultOfMethodCallIgnored: OK to ignore since "skip" always works for ByteArrayInputStream. + in.skip(offset); + + return in; + } else if (channel instanceof ReadableFileFrameChannel) { + // Close frameFile once we've returned an input stream: no need to retain a reference to the mmap after that, + // since we aren't using it. + try (final FrameFile frameFile = ((ReadableFileFrameChannel) channel).newFrameFileReference()) { + final RandomAccessFile randomAccessFile = new RandomAccessFile(frameFile.file(), "r"); + + if (offset >= randomAccessFile.length()) { + randomAccessFile.close(); + return new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY); + } else { + randomAccessFile.seek(offset); + return Channels.newInputStream(randomAccessFile.getChannel()); + } + } + } else { + String errorMsg = StringUtils.format( + "Returned server error to client because channel for [%s] is not nil or file-based (class = %s)", + stagePartition, + channel.getClass().getName() + ); + log.error(StringUtils.encodeForFormat(errorMsg)); + + throw new IOException(errorMsg); + } + } + + @Override + public void postWorkOrder(final WorkOrder workOrder) + { + if (task.getWorkerNumber() != workOrder.getWorkerNumber()) { + throw new ISE("Worker number mismatch: expected [%d]", task.getWorkerNumber()); + } + + kernelManipulationQueue.add( + kernelHolder -> + kernelHolder.getStageKernelMap().computeIfAbsent( + workOrder.getStageDefinition().getId(), + ignored -> WorkerStageKernel.create(workOrder) + ) + ); + } + + @Override + public boolean postResultPartitionBoundaries( + final ClusterByPartitions stagePartitionBoundaries, + final String queryId, + final int stageNumber + ) + { + final StageId stageId = new StageId(queryId, stageNumber); + + kernelManipulationQueue.add( + 
kernelHolder -> { + final WorkerStageKernel stageKernel = kernelHolder.getStageKernelMap().get(stageId); + + // Ignore the update if we don't have a kernel for this stage. + if (stageKernel != null) { + stageKernel.setResultPartitionBoundaries(stagePartitionBoundaries); + } else { + log.warn("Ignored result partition boundaries call for unknown stage [%s]", stageId); + } + } + ); + return true; + } + + @Override + public void postCleanupStage(final StageId stageId) + { + log.info("Cleanup order for stage: [%s] received", stageId); + kernelManipulationQueue.add( + holder -> { + cleanStageOutput(stageId); + // Mark the stage as FINISHED + holder.getStageKernelMap().get(stageId).setStageFinished(); + } + ); + } + + @Override + public void postFinish() + { + kernelManipulationQueue.add(KernelHolder::setDone); + } + + @Override + public CounterSnapshotsTree getCounters() + { + final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); + + for (final Map.Entry<StageId, CounterTracker> entry : stageCounters.entrySet()) { + retVal.put(entry.getKey().getStageNumber(), task().getWorkerNumber(), entry.getValue().snapshot()); + } + + return retVal; + } + + private InputChannelFactory makeBaseInputChannelFactory(final Closer closer) + { + final Supplier<List<String>> workerTaskList = Suppliers.memoize( + () -> { + try { + return controllerClient.getTaskList(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + )::get; + + if (durableStageStorageEnabled) { + return DurableStorageInputChannelFactory.createStandardImplementation( + task.getControllerTaskId(), + workerTaskList, + MSQTasks.makeStorageConnector(context.injector()), + closer + ); + } else { + return new WorkerOrLocalInputChannelFactory(workerTaskList); + } + } + + private OutputChannelFactory makeStageOutputChannelFactory(final FrameContext frameContext, final int stageNumber) + { + // Use the standard frame size, since we assume this size when computing how much is needed to merge output + 
// files from different workers. + final int frameSize = frameContext.memoryParameters().getStandardFrameSize(); + + if (durableStageStorageEnabled) { + return DurableStorageOutputChannelFactory.createStandardImplementation( + task.getControllerTaskId(), + id(), + stageNumber, + frameSize, + MSQTasks.makeStorageConnector(context.injector()) + ); + } else { + final File fileChannelDirectory = + new File(context.tempDir(), StringUtils.format("output_stage_%06d", stageNumber)); + + return new FileOutputChannelFactory(fileChannelDirectory, frameSize); + } + } + + private ListeningExecutorService makeProcessingPool() Review Comment: Added this note: ``` /** * Decorates the server-wide {@link QueryProcessingPool} such that any Callables and Runnables, not just * {@link PrioritizedCallable} and {@link PrioritizedRunnable}, may be added to it. * * In production, the underlying {@link QueryProcessingPool} pool is set up by * {@link org.apache.druid.guice.DruidProcessingModule}. */ ``` ########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java: ########## @@ -0,0 +1,1230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.exec; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.SettableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.frame.allocation.ArenaMemoryAllocator; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ReadableFileFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.file.FrameFile; +import org.apache.druid.frame.file.FrameFileWriter; +import org.apache.druid.frame.key.ClusterBy; +import org.apache.druid.frame.key.ClusterByPartitions; +import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.frame.processor.FileOutputChannelFactory; +import org.apache.druid.frame.processor.FrameChannelMuxer; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessorExecutor; +import org.apache.druid.frame.processor.OutputChannel; +import org.apache.druid.frame.processor.OutputChannelFactory; +import org.apache.druid.frame.processor.OutputChannels; +import org.apache.druid.frame.processor.SuperSorter; +import org.apache.druid.frame.processor.SuperSorterProgressTracker; +import org.apache.druid.indexer.TaskStatus; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Pair; +import 
org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.counters.CounterNames; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.counters.CounterTracker; +import org.apache.druid.msq.indexing.CountingOutputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelFactory; +import org.apache.druid.msq.indexing.InputChannelsImpl; +import org.apache.druid.msq.indexing.KeyStatisticsCollectionProcessor; +import org.apache.druid.msq.indexing.MSQWorkerTask; +import org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher; +import org.apache.druid.msq.indexing.error.MSQWarningReportPublisher; +import org.apache.druid.msq.indexing.error.MSQWarningReportSimplePublisher; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.InputSliceReader; +import org.apache.druid.msq.input.InputSlices; +import org.apache.druid.msq.input.MapInputSliceReader; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.NilInputSliceReader; +import org.apache.druid.msq.input.external.ExternalInputSlice; +import org.apache.druid.msq.input.external.ExternalInputSliceReader; +import org.apache.druid.msq.input.stage.InputChannels; +import org.apache.druid.msq.input.stage.ReadablePartition; +import org.apache.druid.msq.input.stage.StageInputSlice; +import org.apache.druid.msq.input.stage.StageInputSliceReader; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.SegmentsInputSliceReader; +import 
org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.FrameProcessorFactory; +import org.apache.druid.msq.kernel.ProcessorsAndChannels; +import org.apache.druid.msq.kernel.QueryDefinition; +import org.apache.druid.msq.kernel.StageDefinition; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.StagePartition; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.kernel.worker.WorkerStageKernel; +import org.apache.druid.msq.kernel.worker.WorkerStagePhase; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.shuffle.DurableStorageInputChannelFactory; +import org.apache.druid.msq.shuffle.DurableStorageOutputChannelFactory; +import org.apache.druid.msq.shuffle.WorkerInputChannelFactory; +import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; +import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; +import org.apache.druid.msq.util.DecoratedExecutorService; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.PrioritizedCallable; +import org.apache.druid.query.PrioritizedRunnable; +import org.apache.druid.query.QueryProcessingPool; +import org.apache.druid.server.DruidNode; + +import javax.annotation.Nullable; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import 
java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +/** + * Interface for a worker of a multi-stage query. + */ +public class WorkerImpl implements Worker +{ + private static final Logger log = new Logger(WorkerImpl.class); + + private final MSQWorkerTask task; + private final WorkerContext context; + + private final BlockingQueue<Consumer<KernelHolder>> kernelManipulationQueue = new LinkedBlockingDeque<>(); + private final ConcurrentHashMap<StageId, ConcurrentHashMap<Integer, ReadableFrameChannel>> stageOutputs = new ConcurrentHashMap<>(); + private final ConcurrentHashMap<StageId, CounterTracker> stageCounters = new ConcurrentHashMap<>(); + private final boolean durableStageStorageEnabled; + + private volatile DruidNode selfDruidNode; + private volatile ControllerClient controllerClient; + private volatile WorkerClient workerClient; + private volatile Bouncer processorBouncer; + private volatile boolean controllerAlive = true; + + public WorkerImpl(MSQWorkerTask task, WorkerContext context) + { + this.task = task; + this.context = context; + this.durableStageStorageEnabled = MultiStageQueryContext.isDurableStorageEnabled(task.getContext()); + } + + @Override + public String id() + { + return task.getId(); + } + + @Override + public MSQWorkerTask task() + { + return task; + } + + @Override + public TaskStatus run() throws Exception + { + try (final Closer closer = Closer.create()) { + Optional<MSQErrorReport> maybeErrorReport; + + try { + maybeErrorReport = runTask(closer); + } + catch (Throwable e) { + maybeErrorReport = Optional.of( + MSQErrorReport.fromException(id(), MSQTasks.getHostFromSelfNode(selfDruidNode), null, e) + ); + } + + if (maybeErrorReport.isPresent()) { + final MSQErrorReport errorReport = maybeErrorReport.get(); + final String errorLogMessage = 
MSQTasks.errorReportToLogMessage(errorReport); + log.warn(errorLogMessage); + + closer.register(() -> { + if (controllerAlive && controllerClient != null && selfDruidNode != null) { + controllerClient.postWorkerError(id(), errorReport); + } + }); + + return TaskStatus.failure(id(), errorReport.getFault().getCodeWithMessage()); + } else { + return TaskStatus.success(id()); + } + } + } + + /** + * Runs worker logic. Returns an empty Optional on success. On failure, returns an error report for errors that + * happened in other threads; throws exceptions for errors that happened in the main worker loop. + */ + public Optional<MSQErrorReport> runTask(final Closer closer) throws Exception + { + this.selfDruidNode = context.selfNode(); + this.controllerClient = context.makeControllerClient(task.getControllerTaskId()); + closer.register(controllerClient::close); + context.registerWorker(this, closer); // Uses controllerClient, so must be called after that is initialized + this.workerClient = new ExceptionWrappingWorkerClient(context.makeWorkerClient()); + closer.register(workerClient::close); + this.processorBouncer = context.processorBouncer(); + + final KernelHolder kernelHolder = new KernelHolder(); + final String cancellationId = id(); + + final FrameProcessorExecutor workerExec = new FrameProcessorExecutor(makeProcessingPool()); + + // Delete all the stage outputs + closer.register(() -> { + for (final StageId stageId : stageOutputs.keySet()) { + cleanStageOutput(stageId); + } + }); + + // Close stage output processors and running futures (if present) + closer.register(() -> { + try { + workerExec.cancel(cancellationId); + } + catch (InterruptedException e) { + // Strange that cancellation would itself be interrupted. Throw an exception, since this is unexpected. 
+ throw new RuntimeException(e); + } + }); + + final MSQWarningReportPublisher msqWarningReportPublisher = new MSQWarningReportLimiterPublisher( + new MSQWarningReportSimplePublisher( + id(), + controllerClient, + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode) + ) + ); + + closer.register(msqWarningReportPublisher); + + final Map<StageId, SettableFuture<ClusterByPartitions>> partitionBoundariesFutureMap = new HashMap<>(); + + final Map<StageId, FrameContext> stageFrameContexts = new HashMap<>(); + + while (!kernelHolder.isDone()) { + boolean didSomething = false; + + for (final WorkerStageKernel kernel : kernelHolder.getStageKernelMap().values()) { + final StageDefinition stageDefinition = kernel.getStageDefinition(); + + if (kernel.getPhase() == WorkerStagePhase.NEW) { + log.debug("New work order: %s", context.jsonMapper().writeValueAsString(kernel.getWorkOrder())); + + // Create separate inputChannelFactory per stage, because the list of tasks can grow between stages, and + // so we need to avoid the memoization in baseInputChannelFactory. + final InputChannelFactory inputChannelFactory = makeBaseInputChannelFactory(closer); + + // Compute memory parameters for all stages, even ones that haven't been assigned yet, so we can fail-fast + // if some won't work. (We expect that all stages will get assigned to the same pool of workers.) + for (final StageDefinition stageDef : kernel.getWorkOrder().getQueryDefinition().getStageDefinitions()) { + stageFrameContexts.computeIfAbsent( + stageDef.getId(), + stageId -> context.frameContext( + kernel.getWorkOrder().getQueryDefinition(), + stageId.getStageNumber() + ) + ); + } + + // Start working on this stage immediately. 
+ kernel.startReading(); + final SettableFuture<ClusterByPartitions> partitionBoundariesFuture = + startWorkOrder( + kernel, + inputChannelFactory, + stageCounters.computeIfAbsent(stageDefinition.getId(), ignored -> new CounterTracker()), + workerExec, + cancellationId, + context.threadCount(), + stageFrameContexts.get(stageDefinition.getId()), + msqWarningReportPublisher + ); + + if (partitionBoundariesFuture != null) { + if (partitionBoundariesFutureMap.put(stageDefinition.getId(), partitionBoundariesFuture) != null) { + throw new ISE("Work order collision for stage [%s]", stageDefinition.getId()); + } + } + + didSomething = true; + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + + if (kernel.getPhase() == WorkerStagePhase.READING_INPUT && kernel.hasResultKeyStatisticsSnapshot()) { + if (controllerAlive) { + controllerClient.postKeyStatistics( + stageDefinition.getId(), + kernel.getWorkOrder().getWorkerNumber(), + kernel.getResultKeyStatisticsSnapshot() + ); + } + kernel.startPreshuffleWaitingForResultPartitionBoundaries(); + + didSomething = true; + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + + logKernelStatus(kernelHolder.getStageKernelMap().values()); + if (kernel.getPhase() == WorkerStagePhase.PRESHUFFLE_WAITING_FOR_RESULT_PARTITION_BOUNDARIES + && kernel.hasResultPartitionBoundaries()) { + partitionBoundariesFutureMap.get(stageDefinition.getId()).set(kernel.getResultPartitionBoundaries()); + kernel.startPreshuffleWritingOutput(); + + didSomething = true; + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + + if (kernel.getPhase() == WorkerStagePhase.RESULTS_READY + && kernel.addPostedResultsComplete(Pair.of(stageDefinition.getId(), kernel.getWorkOrder().getWorkerNumber()))) { + if (controllerAlive) { + controllerClient.postResultsComplete( + stageDefinition.getId(), + kernel.getWorkOrder().getWorkerNumber(), + kernel.getResultObject() + ); + } + } + + if (kernel.getPhase() == WorkerStagePhase.FAILED) { + 
// Better than throwing an exception, because we can include the stage number. + return Optional.of( + MSQErrorReport.fromException( + id(), + MSQTasks.getHostFromSelfNode(selfDruidNode), + stageDefinition.getId().getStageNumber(), + kernel.getException() + ) + ); + } + } + + if (!didSomething && !kernelHolder.isDone()) { + Consumer<KernelHolder> nextCommand; + + do { + postCountersToController(); + } while ((nextCommand = kernelManipulationQueue.poll(5, TimeUnit.SECONDS)) == null); + + nextCommand.accept(kernelHolder); + logKernelStatus(kernelHolder.getStageKernelMap().values()); + } + } + + // Empty means success. + return Optional.empty(); + } + + @Override + public void stopGracefully() + { + kernelManipulationQueue.add( + kernel -> { + // stopGracefully() is called when the containing process is terminated, or when the task is canceled. + throw new MSQException(CanceledFault.INSTANCE); + } + ); + } + + @Override + public void controllerFailed() + { + controllerAlive = false; + stopGracefully(); + } + + @Override + public InputStream readChannel( + final String queryId, + final int stageNumber, + final int partitionNumber, + final long offset + ) throws IOException + { + final StageId stageId = new StageId(queryId, stageNumber); + final StagePartition stagePartition = new StagePartition(stageId, partitionNumber); + final ConcurrentHashMap<Integer, ReadableFrameChannel> partitionOutputsForStage = stageOutputs.get(stageId); + + if (partitionOutputsForStage == null) { + return null; + } + final ReadableFrameChannel channel = partitionOutputsForStage.get(partitionNumber); + + if (channel == null) { + return null; + } + + if (channel instanceof ReadableNilFrameChannel) { + // Build an empty frame file. 
+ final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + FrameFileWriter.open(Channels.newChannel(baos), null).close(); + + final ByteArrayInputStream in = new ByteArrayInputStream(baos.toByteArray()); + + //noinspection ResultOfMethodCallIgnored: OK to ignore since "skip" always works for ByteArrayInputStream. + in.skip(offset); + + return in; + } else if (channel instanceof ReadableFileFrameChannel) { + // Close frameFile once we've returned an input stream: no need to retain a reference to the mmap after that, + // since we aren't using it. + try (final FrameFile frameFile = ((ReadableFileFrameChannel) channel).newFrameFileReference()) { + final RandomAccessFile randomAccessFile = new RandomAccessFile(frameFile.file(), "r"); + + if (offset >= randomAccessFile.length()) { + randomAccessFile.close(); + return new ByteArrayInputStream(ByteArrays.EMPTY_ARRAY); + } else { + randomAccessFile.seek(offset); + return Channels.newInputStream(randomAccessFile.getChannel()); + } + } + } else { + String errorMsg = StringUtils.format( + "Returned server error to client because channel for [%s] is not nil or file-based (class = %s)", + stagePartition, + channel.getClass().getName() + ); + log.error(StringUtils.encodeForFormat(errorMsg)); + + throw new IOException(errorMsg); + } + } + + @Override + public void postWorkOrder(final WorkOrder workOrder) + { + if (task.getWorkerNumber() != workOrder.getWorkerNumber()) { + throw new ISE("Worker number mismatch: expected [%d]", task.getWorkerNumber()); + } + + kernelManipulationQueue.add( + kernelHolder -> + kernelHolder.getStageKernelMap().computeIfAbsent( + workOrder.getStageDefinition().getId(), + ignored -> WorkerStageKernel.create(workOrder) + ) + ); + } + + @Override + public boolean postResultPartitionBoundaries( + final ClusterByPartitions stagePartitionBoundaries, + final String queryId, + final int stageNumber + ) + { + final StageId stageId = new StageId(queryId, stageNumber); + + kernelManipulationQueue.add( + 
kernelHolder -> { + final WorkerStageKernel stageKernel = kernelHolder.getStageKernelMap().get(stageId); + + // Ignore the update if we don't have a kernel for this stage. + if (stageKernel != null) { + stageKernel.setResultPartitionBoundaries(stagePartitionBoundaries); + } else { + log.warn("Ignored result partition boundaries call for unknown stage [%s]", stageId); + } + } + ); + return true; + } + + @Override + public void postCleanupStage(final StageId stageId) + { + log.info("Cleanup order for stage: [%s] received", stageId); + kernelManipulationQueue.add( + holder -> { + cleanStageOutput(stageId); + // Mark the stage as FINISHED + holder.getStageKernelMap().get(stageId).setStageFinished(); + } + ); + } + + @Override + public void postFinish() + { + kernelManipulationQueue.add(KernelHolder::setDone); + } + + @Override + public CounterSnapshotsTree getCounters() + { + final CounterSnapshotsTree retVal = new CounterSnapshotsTree(); + + for (final Map.Entry<StageId, CounterTracker> entry : stageCounters.entrySet()) { + retVal.put(entry.getKey().getStageNumber(), task().getWorkerNumber(), entry.getValue().snapshot()); + } + + return retVal; + } + + private InputChannelFactory makeBaseInputChannelFactory(final Closer closer) + { + final Supplier<List<String>> workerTaskList = Suppliers.memoize( + () -> { + try { + return controllerClient.getTaskList(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + } + )::get; + + if (durableStageStorageEnabled) { + return DurableStorageInputChannelFactory.createStandardImplementation( + task.getControllerTaskId(), + workerTaskList, + MSQTasks.makeStorageConnector(context.injector()), + closer + ); + } else { + return new WorkerOrLocalInputChannelFactory(workerTaskList); + } + } + + private OutputChannelFactory makeStageOutputChannelFactory(final FrameContext frameContext, final int stageNumber) + { + // Use the standard frame size, since we assume this size when computing how much is needed to merge output + 
// files from different workers. + final int frameSize = frameContext.memoryParameters().getStandardFrameSize(); + + if (durableStageStorageEnabled) { + return DurableStorageOutputChannelFactory.createStandardImplementation( + task.getControllerTaskId(), + id(), + stageNumber, + frameSize, + MSQTasks.makeStorageConnector(context.injector()) + ); + } else { + final File fileChannelDirectory = + new File(context.tempDir(), StringUtils.format("output_stage_%06d", stageNumber)); + + return new FileOutputChannelFactory(fileChannelDirectory, frameSize); + } + } + + private ListeningExecutorService makeProcessingPool() Review Comment: Added this note: ``` /** * Decorates the server-wide {@link QueryProcessingPool} such that any Callables and Runnables, not just * {@link PrioritizedCallable} and {@link PrioritizedRunnable}, may be added to it. * * In production, the underlying {@link QueryProcessingPool} pool is set up by * {@link org.apache.druid.guice.DruidProcessingModule}. */ ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
