imply-cheddar commented on a change in pull request #12236:
URL: https://github.com/apache/druid/pull/12236#discussion_r803294216



##########
File path: 
indexing-service/src/main/java/org/apache/druid/indexing/common/task/batch/parallel/ParallelIndexSupervisorTask.java
##########
@@ -923,58 +924,66 @@ private PartitionBoundaries 
determineRangePartition(Collection<StringDistributio
     return partitions;
   }
 
-  private static Map<Pair<Interval, Integer>, List<PartitionLocation>> 
groupGenericPartitionLocationsPerPartition(
+  /**
+   * Creates a map from partition (interval + bucketId) to the corresponding
+   * PartitionLocations. Note that the bucketId may be different from the final
+   * partitionId (refer to {@link BuildingShardSpec} for more details).
+   */
+  static Map<Partition, List<PartitionLocation>> getPartitionToLocations(
       Map<String, GeneratedPartitionsReport> subTaskIdToReport
   )
   {
-    final Map<Pair<Interval, Integer>, BuildingShardSpec<?>> 
intervalAndIntegerToShardSpec = new HashMap<>();
-    final Object2IntMap<Interval> intervalToNextPartitionId = new 
Object2IntOpenHashMap<>();
-    final BiFunction<String, PartitionStat, PartitionLocation> 
createPartitionLocationFunction =
-        (subtaskId, partitionStat) -> {
-          final BuildingShardSpec<?> shardSpec = 
intervalAndIntegerToShardSpec.computeIfAbsent(
-              Pair.of(partitionStat.getInterval(), 
partitionStat.getBucketId()),
-              key -> {
-                // Lazily determine the partitionId to create packed 
partitionIds for the core partitions.
-                // See the Javadoc of BucketNumberedShardSpec for details.
-                final int partitionId = intervalToNextPartitionId.computeInt(
-                    partitionStat.getInterval(),
-                    ((interval, nextPartitionId) -> nextPartitionId == null ? 
0 : nextPartitionId + 1)
-                );
-                return 
partitionStat.getSecondaryPartition().convert(partitionId);
-              }
-          );
-          return partitionStat.toPartitionLocation(subtaskId, shardSpec);
-        };
+    // Create a map from partition to list of reports (PartitionStat and 
subTaskId)
+    final Map<Partition, List<PartitionReport>> partitionToReports = new 
TreeMap<>(
+        // Sort by (interval, bucketId) to maintain order of partitionIds 
within interval
+        Comparator
+            .comparingLong((Partition partition) -> 
partition.getInterval().getStartMillis())
+            .thenComparingLong(partition -> 
partition.getInterval().getEndMillis())
+            .thenComparingInt(Partition::getBucketId)
+    );
+    subTaskIdToReport.forEach(
+        (subTaskId, report) -> report.getPartitionStats().forEach(
+            partitionStat -> partitionToReports
+                .computeIfAbsent(Partition.fromStat(partitionStat), p -> new 
ArrayList<>())
+                .add(new PartitionReport(subTaskId, partitionStat))
+        )
+    );
 
-    return groupPartitionLocationsPerPartition(subTaskIdToReport, 
createPartitionLocationFunction);
-  }
+    final Map<Partition, List<PartitionLocation>> partitionToLocations = new 
HashMap<>();
 
-  private static <L extends PartitionLocation>
-      Map<Pair<Interval, Integer>, List<L>> 
groupPartitionLocationsPerPartition(
-      Map<String, ? extends GeneratedPartitionsReport> subTaskIdToReport,
-      BiFunction<String, PartitionStat, L> createPartitionLocationFunction
-  )
-  {
-    // partition (interval, partitionId) -> partition locations
-    final Map<Pair<Interval, Integer>, List<L>> partitionToLocations = new 
HashMap<>();
-    for (Entry<String, ? extends GeneratedPartitionsReport> entry : 
subTaskIdToReport.entrySet()) {
-      final String subTaskId = entry.getKey();
-      final GeneratedPartitionsReport report = entry.getValue();
-      for (PartitionStat partitionStat : report.getPartitionStats()) {
-        final List<L> locationsOfSamePartition = 
partitionToLocations.computeIfAbsent(
-            Pair.of(partitionStat.getInterval(), partitionStat.getBucketId()),
-            k -> new ArrayList<>()
-        );
-        
locationsOfSamePartition.add(createPartitionLocationFunction.apply(subTaskId, 
partitionStat));
+    Interval prevInterval = null;
+    final AtomicInteger partitionId = new AtomicInteger(0);
+    for (Entry<Partition, List<PartitionReport>> entry : 
partitionToReports.entrySet()) {
+      final Partition partition = entry.getKey();
+
+      // Reset the partitionId if this is a new interval
+      Interval interval = partition.getInterval();
+      if (!interval.equals(prevInterval)) {
+        partitionId.set(0);
+        prevInterval = interval;
       }
+
+      // Use any PartitionStat of this partition to create a shard spec
+      final List<PartitionReport> reportsOfPartition = entry.getValue();
+      final BuildingShardSpec<?> shardSpec = reportsOfPartition
+          .get(0).getPartitionStat().getSecondaryPartition()
+          .convert(partitionId.getAndIncrement());

Review comment:
       Okay, looking at this code, we seem to be passing what is supposed to 
be a new `bucketId` into `convert()`. However, when I look at 
`BuildingShardSpec.convert()`, the argument is called `numCorePartitions`, which 
would be the total number of partitions, not the actual bucket number. This 
seems like a bug to me, and it now has me wondering whether the previous code was 
actually doing the right thing.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to