This is an automated email from the ASF dual-hosted git repository.

karan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 06bbdb38ce8 MSQ: Allow for worker gaps. (#17277)
06bbdb38ce8 is described below

commit 06bbdb38ce8fe438e37ee4ac5e457d3b62cc64df
Author: Gian Merlino <[email protected]>
AuthorDate: Tue Oct 8 02:37:57 2024 -0700

    MSQ: Allow for worker gaps. (#17277)
    
    In a Dart query, all Historicals are given worker IDs, but not all of them
    are going to actually be started or receive work orders. This can create gaps
    in the set of workers. For example, workers 1 and 3 could have work assigned
    while workers 0 and 2 do not.
    
    This patch updates ControllerStageTracker and WorkerInputs to handle such
    gaps, by using the set of actual worker numbers, rather than 0..workerCount,
    in various places.
---
 .../druid/msq/input/stage/ReadablePartition.java   |  12 ++
 .../druid/msq/input/stage/ReadablePartitions.java  |  33 ++++-
 .../stage/SparseStripedReadablePartitions.java     | 142 +++++++++++++++++++++
 .../kernel/controller/ControllerStageTracker.java  |  26 ++--
 .../druid/msq/kernel/controller/WorkerInputs.java  |  41 +++---
 .../stage/CollectedReadablePartitionsTest.java     |  12 +-
 .../stage/CombinedReadablePartitionsTest.java      |   2 +-
 ...va => SparseStripedReadablePartitionsTest.java} |  29 +++--
 .../input/stage/StripedReadablePartitionsTest.java |  34 ++++-
 .../msq/kernel/controller/WorkerInputsTest.java    |  98 +++++++++++---
 10 files changed, 359 insertions(+), 70 deletions(-)

diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java
index 99098d1d4cb..5f366c60009 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartition.java
@@ -59,6 +59,18 @@ public class ReadablePartition
     return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
   }
 
+  /**
+   * Returns an output partition that is striped across a set of {@code 
workerNumbers}.
+   */
+  public static ReadablePartition striped(
+      final int stageNumber,
+      final IntSortedSet workerNumbers,
+      final int partitionNumber
+  )
+  {
+    return new ReadablePartition(stageNumber, workerNumbers, partitionNumber);
+  }
+
   /**
    * Returns an output partition that has been collected onto a single worker.
    */
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java
index a71535fbcfc..dcf0042f68b 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/ReadablePartitions.java
@@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
 import it.unimi.dsi.fastutil.ints.Int2IntAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2IntSortedMap;
 import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 
 import java.util.Collections;
 import java.util.List;
@@ -39,6 +40,7 @@ import java.util.Map;
 @JsonSubTypes(value = {
     @JsonSubTypes.Type(name = "collected", value = 
CollectedReadablePartitions.class),
     @JsonSubTypes.Type(name = "striped", value = 
StripedReadablePartitions.class),
+    @JsonSubTypes.Type(name = "sparseStriped", value = 
SparseStripedReadablePartitions.class),
     @JsonSubTypes.Type(name = "combined", value = 
CombinedReadablePartitions.class)
 })
 public interface ReadablePartitions extends Iterable<ReadablePartition>
@@ -59,7 +61,7 @@ public interface ReadablePartitions extends 
Iterable<ReadablePartition>
   /**
    * Combines various sets of partitions into a single set.
    */
-  static CombinedReadablePartitions combine(List<ReadablePartitions> 
readablePartitions)
+  static ReadablePartitions combine(List<ReadablePartitions> 
readablePartitions)
   {
     return new CombinedReadablePartitions(readablePartitions);
   }
@@ -68,7 +70,7 @@ public interface ReadablePartitions extends 
Iterable<ReadablePartition>
    * Returns a set of {@code numPartitions} partitions striped across {@code 
numWorkers} workers: each worker contains
    * a "stripe" of each partition.
    */
-  static StripedReadablePartitions striped(
+  static ReadablePartitions striped(
       final int stageNumber,
       final int numWorkers,
       final int numPartitions
@@ -82,11 +84,36 @@ public interface ReadablePartitions extends 
Iterable<ReadablePartition>
     return new StripedReadablePartitions(stageNumber, numWorkers, 
partitionNumbers);
   }
 
+  /**
+   * Returns a set of {@code numPartitions} partitions striped across {@code 
workers}: each worker contains
+   * a "stripe" of each partition.
+   */
+  static ReadablePartitions striped(
+      final int stageNumber,
+      final IntSortedSet workers,
+      final int numPartitions
+  )
+  {
+    final IntAVLTreeSet partitionNumbers = new IntAVLTreeSet();
+    for (int i = 0; i < numPartitions; i++) {
+      partitionNumbers.add(i);
+    }
+
+    if (workers.lastInt() == workers.size() - 1) {
+      // Dense worker set. Use StripedReadablePartitions for compactness (send 
a single number rather than the
+      // entire worker set) and for backwards compatibility (older workers 
cannot understand
+      // SparseStripedReadablePartitions).
+      return new StripedReadablePartitions(stageNumber, workers.size(), 
partitionNumbers);
+    } else {
+      return new SparseStripedReadablePartitions(stageNumber, workers, 
partitionNumbers);
+    }
+  }
+
   /**
    * Returns a set of partitions that have been collected onto specific 
workers: each partition is on exactly
    * one worker.
    */
-  static CollectedReadablePartitions collected(
+  static ReadablePartitions collected(
       final int stageNumber,
       final Map<Integer, Integer> partitionToWorkerMap
   )
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java
new file mode 100644
index 00000000000..e9a02a7d488
--- /dev/null
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitions.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.input.stage;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.collect.Iterators;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
+import org.apache.druid.msq.input.SlicerUtils;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * Set of partitions striped across a sparse set of {@code workers}. Each 
worker contains a "stripe" of each partition.
+ *
+ * @see StripedReadablePartitions dense version, where workers from [0..N) are 
all used.
+ */
+public class SparseStripedReadablePartitions implements ReadablePartitions
+{
+  private final int stageNumber;
+  private final IntSortedSet workers;
+  private final IntSortedSet partitionNumbers;
+
+  /**
+   * Constructor. Most callers should use {@link 
ReadablePartitions#striped(int, int, int)} instead, which takes
+   * a partition count rather than a set of partition numbers.
+   */
+  public SparseStripedReadablePartitions(
+      final int stageNumber,
+      final IntSortedSet workers,
+      final IntSortedSet partitionNumbers
+  )
+  {
+    this.stageNumber = stageNumber;
+    this.workers = workers;
+    this.partitionNumbers = partitionNumbers;
+  }
+
+  @JsonCreator
+  private SparseStripedReadablePartitions(
+      @JsonProperty("stageNumber") final int stageNumber,
+      @JsonProperty("workers") final Set<Integer> workers,
+      @JsonProperty("partitionNumbers") final Set<Integer> partitionNumbers
+  )
+  {
+    this(stageNumber, new IntAVLTreeSet(workers), new 
IntAVLTreeSet(partitionNumbers));
+  }
+
+  @Override
+  public Iterator<ReadablePartition> iterator()
+  {
+    return Iterators.transform(
+        partitionNumbers.iterator(),
+        partitionNumber -> ReadablePartition.striped(stageNumber, workers, 
partitionNumber)
+    );
+  }
+
+  @Override
+  public List<ReadablePartitions> split(final int maxNumSplits)
+  {
+    final List<ReadablePartitions> retVal = new ArrayList<>();
+
+    for (List<Integer> entries : 
SlicerUtils.makeSlicesStatic(partitionNumbers.iterator(), maxNumSplits)) {
+      if (!entries.isEmpty()) {
+        retVal.add(new SparseStripedReadablePartitions(stageNumber, workers, 
new IntAVLTreeSet(entries)));
+      }
+    }
+
+    return retVal;
+  }
+
+  @JsonProperty
+  int getStageNumber()
+  {
+    return stageNumber;
+  }
+
+  @JsonProperty
+  IntSortedSet getWorkers()
+  {
+    return workers;
+  }
+
+  @JsonProperty
+  IntSortedSet getPartitionNumbers()
+  {
+    return partitionNumbers;
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    SparseStripedReadablePartitions that = (SparseStripedReadablePartitions) o;
+    return stageNumber == that.stageNumber
+           && Objects.equals(workers, that.workers)
+           && Objects.equals(partitionNumbers, that.partitionNumbers);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(stageNumber, workers, partitionNumbers);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "StripedReadablePartitions{" +
+           "stageNumber=" + stageNumber +
+           ", workers=" + workers +
+           ", partitionNumbers=" + partitionNumbers +
+           '}';
+  }
+}
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
index 338a35e0d24..533cb57b97f 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/ControllerStageTracker.java
@@ -403,7 +403,7 @@ class ControllerStageTracker
       throw new ISE("Stage does not gather result key statistics");
     }
 
-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
 
@@ -522,7 +522,7 @@ class ControllerStageTracker
       throw new ISE("Stage does not gather result key statistics");
     }
 
-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
 
@@ -656,7 +656,7 @@ class ControllerStageTracker
       throw new ISE("Stage does not gather result key statistics");
     }
 
-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
 
@@ -763,7 +763,7 @@ class ControllerStageTracker
     this.resultPartitionBoundaries = clusterByPartitions;
     this.resultPartitions = ReadablePartitions.striped(
         stageDef.getStageNumber(),
-        workerCount,
+        workerInputs.workers(),
         clusterByPartitions.size()
     );
 
@@ -788,7 +788,7 @@ class ControllerStageTracker
       throw DruidException.defensive("Cannot setDoneReadingInput for 
stage[%s], it is not sorting", stageDef.getId());
     }
 
-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber[%s] for stage[%s]", workerNumber, 
stageDef.getId());
     }
 
@@ -830,7 +830,7 @@ class ControllerStageTracker
   @SuppressWarnings("unchecked")
   boolean setResultsCompleteForWorker(final int workerNumber, final Object 
resultObject)
   {
-    if (workerNumber < 0 || workerNumber >= workerCount) {
+    if (!workerInputs.workers().contains(workerNumber)) {
       throw new IAE("Invalid workerNumber [%s]", workerNumber);
     }
 
@@ -947,14 +947,18 @@ class ControllerStageTracker
         resultPartitionBoundaries = 
maybeResultPartitionBoundaries.valueOrThrow();
         resultPartitions = ReadablePartitions.striped(
             stageNumber,
-            workerCount,
+            workerInputs.workers(),
             resultPartitionBoundaries.size()
         );
-      } else if (shuffleSpec.kind() == ShuffleKind.MIX) {
-        resultPartitionBoundaries = 
ClusterByPartitions.oneUniversalPartition();
-        resultPartitions = ReadablePartitions.striped(stageNumber, 
workerCount, shuffleSpec.partitionCount());
       } else {
-        resultPartitions = ReadablePartitions.striped(stageNumber, 
workerCount, shuffleSpec.partitionCount());
+        if (shuffleSpec.kind() == ShuffleKind.MIX) {
+          resultPartitionBoundaries = 
ClusterByPartitions.oneUniversalPartition();
+        }
+        resultPartitions = ReadablePartitions.striped(
+            stageNumber,
+            workerInputs.workers(),
+            shuffleSpec.partitionCount()
+        );
       }
     } else {
       // No reshuffling: retain partitioning from nonbroadcast inputs.
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java
index 83d7a602bc1..8dcaee9c213 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/kernel/controller/WorkerInputs.java
@@ -24,7 +24,9 @@ import com.google.common.collect.Iterables;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
-import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
+import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 import it.unimi.dsi.fastutil.objects.ObjectIterator;
 import org.apache.druid.msq.input.InputSlice;
 import org.apache.druid.msq.input.InputSpec;
@@ -45,9 +47,9 @@ import java.util.stream.IntStream;
 public class WorkerInputs
 {
   // Worker number -> input number -> input slice.
-  private final Int2ObjectMap<List<InputSlice>> assignmentsMap;
+  private final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap;
 
-  private WorkerInputs(final Int2ObjectMap<List<InputSlice>> assignmentsMap)
+  private WorkerInputs(final Int2ObjectSortedMap<List<InputSlice>> 
assignmentsMap)
   {
     this.assignmentsMap = assignmentsMap;
   }
@@ -64,7 +66,7 @@ public class WorkerInputs
   )
   {
     // Split each inputSpec and assign to workers. This list maps worker 
number -> input number -> input slice.
-    final Int2ObjectMap<List<InputSlice>> assignmentsMap = new 
Int2ObjectAVLTreeMap<>();
+    final Int2ObjectSortedMap<List<InputSlice>> assignmentsMap = new 
Int2ObjectAVLTreeMap<>();
     final int numInputs = stageDef.getInputSpecs().size();
 
     if (numInputs == 0) {
@@ -117,8 +119,8 @@ public class WorkerInputs
 
     final ObjectIterator<Int2ObjectMap.Entry<List<InputSlice>>> 
assignmentsIterator =
         assignmentsMap.int2ObjectEntrySet().iterator();
+    final IntSortedSet nilWorkers = new IntAVLTreeSet();
 
-    boolean first = true;
     while (assignmentsIterator.hasNext()) {
       final Int2ObjectMap.Entry<List<InputSlice>> entry = 
assignmentsIterator.next();
       final List<InputSlice> slices = entry.getValue();
@@ -130,20 +132,29 @@ public class WorkerInputs
         }
       }
 
-      // Eliminate workers that have no non-nil, non-broadcast inputs. (Except 
the first one, because if all input
-      // is nil, *some* worker has to do *something*.)
-      final boolean hasNonNilNonBroadcastInput =
+      // Identify nil workers (workers with no non-broadcast inputs).
+      final boolean isNilWorker =
           IntStream.range(0, numInputs)
-                   .anyMatch(i ->
-                                 !slices.get(i).equals(NilInputSlice.INSTANCE) 
 // Non-nil
-                                 && 
!stageDef.getBroadcastInputNumbers().contains(i) // Non-broadcast
+                   .allMatch(i ->
+                                 slices.get(i).equals(NilInputSlice.INSTANCE)  
// Nil regular input
+                                 || 
stageDef.getBroadcastInputNumbers().contains(i) // Broadcast
                    );
 
-      if (!first && !hasNonNilNonBroadcastInput) {
-        assignmentsIterator.remove();
+      if (isNilWorker) {
+        nilWorkers.add(entry.getIntKey());
       }
+    }
 
-      first = false;
+    if (nilWorkers.size() == assignmentsMap.size()) {
+      // All workers have nil regular inputs. Remove all workers exept the 
first (*some* worker has to do *something*).
+      final List<InputSlice> firstSlices = 
assignmentsMap.get(nilWorkers.firstInt());
+      assignmentsMap.clear();
+      assignmentsMap.put(nilWorkers.firstInt(), firstSlices);
+    } else {
+      // Remove all nil workers.
+      for (final int nilWorker : nilWorkers) {
+        assignmentsMap.remove(nilWorker);
+      }
     }
 
     return new WorkerInputs(assignmentsMap);
@@ -154,7 +165,7 @@ public class WorkerInputs
     return Preconditions.checkNotNull(assignmentsMap.get(workerNumber), 
"worker [%s]", workerNumber);
   }
 
-  public IntSet workers()
+  public IntSortedSet workers()
   {
     return assignmentsMap.keySet();
   }
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java
index 6ed7d2d43d4..d4db7a0a7c5 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CollectedReadablePartitionsTest.java
@@ -33,21 +33,24 @@ public class CollectedReadablePartitionsTest
   @Test
   public void testPartitionToWorkerMap()
   {
-    final CollectedReadablePartitions partitions = 
ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, 
ImmutableMap.of(0, 1, 1, 2, 2, 1));
     Assert.assertEquals(ImmutableMap.of(0, 1, 1, 2, 2, 1), 
partitions.getPartitionToWorkerMap());
   }
 
   @Test
   public void testStageNumber()
   {
-    final CollectedReadablePartitions partitions = 
ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, 
ImmutableMap.of(0, 1, 1, 2, 2, 1));
     Assert.assertEquals(1, partitions.getStageNumber());
   }
 
   @Test
   public void testSplit()
   {
-    final CollectedReadablePartitions partitions = 
ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, 
ImmutableMap.of(0, 1, 1, 2, 2, 1));
 
     Assert.assertEquals(
         ImmutableList.of(
@@ -64,7 +67,8 @@ public class CollectedReadablePartitionsTest
     final ObjectMapper mapper = TestHelper.makeJsonMapper()
                                           .registerModules(new 
MSQIndexingModule().getJacksonModules());
 
-    final CollectedReadablePartitions partitions = 
ReadablePartitions.collected(1, ImmutableMap.of(0, 1, 1, 2, 2, 1));
+    final CollectedReadablePartitions partitions =
+        (CollectedReadablePartitions) ReadablePartitions.collected(1, 
ImmutableMap.of(0, 1, 1, 2, 2, 1));
 
     Assert.assertEquals(
         partitions,
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java
index 685f4ff7a8a..16bd047b624 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/CombinedReadablePartitionsTest.java
@@ -31,7 +31,7 @@ import org.junit.Test;
 
 public class CombinedReadablePartitionsTest
 {
-  private static final CombinedReadablePartitions PARTITIONS = 
ReadablePartitions.combine(
+  private static final ReadablePartitions PARTITIONS = 
ReadablePartitions.combine(
       ImmutableList.of(
           ReadablePartitions.striped(0, 2, 2),
           ReadablePartitions.striped(1, 2, 4)
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java
similarity index 61%
copy from 
extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
copy to 
extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java
index 38e0707f5d0..5268fd60180 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/SparseStripedReadablePartitionsTest.java
@@ -23,44 +23,50 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
+import it.unimi.dsi.fastutil.ints.IntSet;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.segment.TestHelper;
 import org.junit.Assert;
 import org.junit.Test;
 
-public class StripedReadablePartitionsTest
+public class SparseStripedReadablePartitionsTest
 {
   @Test
   public void testPartitionNumbers()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final SparseStripedReadablePartitions partitions =
+        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new 
IntAVLTreeSet(new int[]{1, 3}), 3);
     Assert.assertEquals(ImmutableSet.of(0, 1, 2), 
partitions.getPartitionNumbers());
   }
 
   @Test
-  public void testNumWorkers()
+  public void testWorkers()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
-    Assert.assertEquals(2, partitions.getNumWorkers());
+    final SparseStripedReadablePartitions partitions =
+        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new 
IntAVLTreeSet(new int[]{1, 3}), 3);
+    Assert.assertEquals(IntSet.of(1, 3), partitions.getWorkers());
   }
 
   @Test
   public void testStageNumber()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final SparseStripedReadablePartitions partitions =
+        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, new 
IntAVLTreeSet(new int[]{1, 3}), 3);
     Assert.assertEquals(1, partitions.getStageNumber());
   }
 
   @Test
   public void testSplit()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
+    final SparseStripedReadablePartitions partitions =
+        (SparseStripedReadablePartitions) ReadablePartitions.striped(1, 
workers, 3);
 
     Assert.assertEquals(
         ImmutableList.of(
-            new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new int[]{0, 
2})),
-            new StripedReadablePartitions(1, 2, new IntAVLTreeSet(new 
int[]{1}))
+            new SparseStripedReadablePartitions(1, workers, new 
IntAVLTreeSet(new int[]{0, 2})),
+            new SparseStripedReadablePartitions(1, workers, new 
IntAVLTreeSet(new int[]{1}))
         ),
         partitions.split(2)
     );
@@ -72,7 +78,8 @@ public class StripedReadablePartitionsTest
     final ObjectMapper mapper = TestHelper.makeJsonMapper()
                                           .registerModules(new 
MSQIndexingModule().getJacksonModules());
 
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final IntAVLTreeSet workers = new IntAVLTreeSet(new int[]{1, 3});
+    final ReadablePartitions partitions = ReadablePartitions.striped(1, 
workers, 3);
 
     Assert.assertEquals(
         partitions,
@@ -86,6 +93,6 @@ public class StripedReadablePartitionsTest
   @Test
   public void testEquals()
   {
-    
EqualsVerifier.forClass(StripedReadablePartitions.class).usingGetClass().verify();
+    
EqualsVerifier.forClass(SparseStripedReadablePartitions.class).usingGetClass().verify();
   }
 }
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
index 38e0707f5d0..05b42b33250 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/input/stage/StripedReadablePartitionsTest.java
@@ -26,36 +26,60 @@ import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
 import nl.jqno.equalsverifier.EqualsVerifier;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.segment.TestHelper;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
 import org.junit.Assert;
 import org.junit.Test;
 
 public class StripedReadablePartitionsTest
 {
+  @Test
+  public void testFromDenseSet()
+  {
+    // Tests that when ReadablePartitions.striped is called with a dense set, 
we get StripedReadablePartitions.
+
+    final IntAVLTreeSet workers = new IntAVLTreeSet();
+    workers.add(0);
+    workers.add(1);
+
+    final ReadablePartitions readablePartitionsFromSet = 
ReadablePartitions.striped(1, workers, 3);
+
+    MatcherAssert.assertThat(
+        readablePartitionsFromSet,
+        CoreMatchers.instanceOf(StripedReadablePartitions.class)
+    );
+
+    Assert.assertEquals(
+        ReadablePartitions.striped(1, 2, 3),
+        readablePartitionsFromSet
+    );
+  }
+
   @Test
   public void testPartitionNumbers()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final StripedReadablePartitions partitions = (StripedReadablePartitions) 
ReadablePartitions.striped(1, 2, 3);
     Assert.assertEquals(ImmutableSet.of(0, 1, 2), 
partitions.getPartitionNumbers());
   }
 
   @Test
   public void testNumWorkers()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final StripedReadablePartitions partitions = (StripedReadablePartitions) 
ReadablePartitions.striped(1, 2, 3);
     Assert.assertEquals(2, partitions.getNumWorkers());
   }
 
   @Test
   public void testStageNumber()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final StripedReadablePartitions partitions = (StripedReadablePartitions) 
ReadablePartitions.striped(1, 2, 3);
     Assert.assertEquals(1, partitions.getStageNumber());
   }
 
   @Test
   public void testSplit()
   {
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
 
     Assert.assertEquals(
         ImmutableList.of(
@@ -72,7 +96,7 @@ public class StripedReadablePartitionsTest
     final ObjectMapper mapper = TestHelper.makeJsonMapper()
                                           .registerModules(new 
MSQIndexingModule().getJacksonModules());
 
-    final StripedReadablePartitions partitions = ReadablePartitions.striped(1, 
2, 3);
+    final ReadablePartitions partitions = ReadablePartitions.striped(1, 2, 3);
 
     Assert.assertEquals(
         partitions,
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java
index 605e0bf2de7..e74125b0830 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/kernel/controller/WorkerInputsTest.java
@@ -25,9 +25,11 @@ import it.unimi.dsi.fastutil.ints.Int2IntMaps;
 import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
 import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
+import it.unimi.dsi.fastutil.ints.IntSortedSet;
 import it.unimi.dsi.fastutil.longs.LongArrayList;
 import it.unimi.dsi.fastutil.longs.LongList;
 import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.error.DruidException;
 import org.apache.druid.msq.exec.Limits;
 import org.apache.druid.msq.exec.OutputChannelMode;
 import org.apache.druid.msq.input.InputSlice;
@@ -75,7 +77,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.MAX,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -91,6 +93,35 @@ public class WorkerInputsTest
     );
   }
 
+  @Test
+  public void test_max_threeInputs_fourWorkers_withGaps()
+  {
+    final StageDefinition stageDef =
+        StageDefinition.builder(0)
+                       .inputs(new TestInputSpec(1, 2, 3))
+                       .maxWorkerCount(4)
+                       .processorFactory(new 
OffsetLimitFrameProcessorFactory(0, 0L))
+                       .build(QUERY_ID);
+
+    final WorkerInputs inputs = WorkerInputs.create(
+        stageDef,
+        Int2IntMaps.EMPTY_MAP,
+        new TestInputSpecSlicer(new IntAVLTreeSet(new int[]{1, 3, 4, 5}), 
true),
+        WorkerAssignmentStrategy.MAX,
+        Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
+    );
+
+    Assert.assertEquals(
+        ImmutableMap.<Integer, List<InputSlice>>builder()
+                    .put(1, Collections.singletonList(new TestInputSlice(1)))
+                    .put(3, Collections.singletonList(new TestInputSlice(2)))
+                    .put(4, Collections.singletonList(new TestInputSlice(3)))
+                    .put(5, Collections.singletonList(new TestInputSlice()))
+                    .build(),
+        inputs.assignmentsMap()
+    );
+  }
+
   @Test
   public void test_max_zeroInputs_fourWorkers()
   {
@@ -104,7 +135,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.MAX,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -133,7 +164,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -159,7 +190,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -186,7 +217,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -212,7 +243,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -324,7 +355,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(4), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -351,7 +382,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(2), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -384,7 +415,7 @@ public class WorkerInputsTest
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
         Int2IntMaps.EMPTY_MAP,
-        new TestInputSpecSlicer(true),
+        new TestInputSpecSlicer(denseWorkers(1), true),
         WorkerAssignmentStrategy.AUTO,
         Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
     );
@@ -411,7 +442,7 @@ public class WorkerInputsTest
                        .processorFactory(new 
OffsetLimitFrameProcessorFactory(0, 0L))
                        .build(QUERY_ID);
 
-    TestInputSpecSlicer testInputSpecSlicer = spy(new 
TestInputSpecSlicer(true));
+    TestInputSpecSlicer testInputSpecSlicer = spy(new 
TestInputSpecSlicer(denseWorkers(3), true));
 
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
@@ -455,7 +486,7 @@ public class WorkerInputsTest
                        .processorFactory(new 
OffsetLimitFrameProcessorFactory(0, 0L))
                        .build(QUERY_ID);
 
-    TestInputSpecSlicer testInputSpecSlicer = spy(new 
TestInputSpecSlicer(true));
+    TestInputSpecSlicer testInputSpecSlicer = spy(new 
TestInputSpecSlicer(denseWorkers(3), true));
 
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
@@ -498,7 +529,7 @@ public class WorkerInputsTest
                        .processorFactory(new 
OffsetLimitFrameProcessorFactory(0, 0L))
                        .build(QUERY_ID);
 
-    TestInputSpecSlicer testInputSpecSlicer = spy(new 
TestInputSpecSlicer(true));
+    TestInputSpecSlicer testInputSpecSlicer = spy(new 
TestInputSpecSlicer(denseWorkers(3), true));
 
     final WorkerInputs inputs = WorkerInputs.create(
         stageDef,
@@ -585,11 +616,23 @@ public class WorkerInputsTest
 
   private static class TestInputSpecSlicer implements InputSpecSlicer
   {
+    private final IntSortedSet workers;
     private final boolean canSliceDynamic;
 
-    public TestInputSpecSlicer(boolean canSliceDynamic)
+    /**
+     * Create a test slicer.
+     *
+     * @param workers         Set of workers to consider assigning work to.
+     * @param canSliceDynamic Whether this slicer can slice dynamically.
+     */
+    public TestInputSpecSlicer(final IntSortedSet workers, final boolean 
canSliceDynamic)
     {
+      this.workers = workers;
       this.canSliceDynamic = canSliceDynamic;
+
+      if (workers.isEmpty()) {
+        throw DruidException.defensive("Need more than one worker in 
workers[%s]", workers);
+      }
     }
 
     @Override
@@ -606,9 +649,9 @@ public class WorkerInputsTest
           SlicerUtils.makeSlicesStatic(
               testInputSpec.values.iterator(),
               i -> i,
-              maxNumSlices
+              Math.min(maxNumSlices, workers.size())
           );
-      return makeSlices(assignments);
+      return makeSlices(workers, assignments);
     }
 
     @Override
@@ -624,24 +667,39 @@ public class WorkerInputsTest
           SlicerUtils.makeSlicesDynamic(
               testInputSpec.values.iterator(),
               i -> i,
-              maxNumSlices,
+              Math.min(maxNumSlices, workers.size()),
               maxFilesPerSlice,
               maxBytesPerSlice
           );
-      return makeSlices(assignments);
+      return makeSlices(workers, assignments);
     }
 
     private static List<InputSlice> makeSlices(
+        final IntSortedSet workers,
         final List<List<Long>> assignments
     )
     {
       final List<InputSlice> retVal = new ArrayList<>(assignments.size());
-
-      for (final List<Long> assignment : assignments) {
-        retVal.add(new TestInputSlice(new LongArrayList(assignment)));
+      for (int assignment = 0, workerNumber = 0;
+           workerNumber <= workers.lastInt() && assignment < 
assignments.size();
+           workerNumber++) {
+        if (workers.contains(workerNumber)) {
+          retVal.add(new TestInputSlice(new 
LongArrayList(assignments.get(assignment++))));
+        } else {
+          retVal.add(NilInputSlice.INSTANCE);
+        }
       }
 
       return retVal;
     }
   }
+
+  private static IntSortedSet denseWorkers(final int numWorkers)
+  {
+    final IntAVLTreeSet workers = new IntAVLTreeSet();
+    for (int i = 0; i < numWorkers; i++) {
+      workers.add(i);
+    }
+    return workers;
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to