This is an automated email from the ASF dual-hosted git repository.
gian pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new b2eed67a2c5 feat: Combiner for MSQ groupBys. (#19193)
b2eed67a2c5 is described below
commit b2eed67a2c5c93934915e4890ff92c7586fa7a5e
Author: Gian Merlino <[email protected]>
AuthorDate: Thu Mar 26 09:04:11 2026 -0700
feat: Combiner for MSQ groupBys. (#19193)
This patch adds optional row combining during merge-sort for MSQ groupBy.
When enabled using the "useCombiner" parameter (default: false),
FrameChannelMerger detects adjacent rows with identical sort keys and
combines them using a FrameCombiner provided by the query logic.
Main files:
1) FrameCombiner: stateful row combiner interface.
2) FrameChannelMerger: logic for determining which rows to combine.
3) GroupByFrameCombiner: implementation for the groupBy query.
---
.../frame/FrameChannelMergerBenchmark.java | 1 +
.../msq/exec/std/StandardPartitionReader.java | 13 +
.../msq/exec/std/StandardShuffleOperations.java | 16 +-
.../druid/msq/exec/std/StandardStageRunner.java | 17 +-
.../processor/SegmentGeneratorStageProcessor.java | 14 +-
.../druid/msq/querykit/BaseLeafStageProcessor.java | 15 +-
.../apache/druid/msq/querykit/QueryKitUtils.java | 22 +-
.../apache/druid/msq/querykit/ReadableInput.java | 60 +--
.../druid/msq/querykit/ReadableInputQueue.java | 11 +-
.../WindowOperatorQueryStageProcessor.java | 6 +-
.../querykit/common/OffsetLimitStageProcessor.java | 4 +-
.../common/SortMergeJoinStageProcessor.java | 10 +-
.../msq/querykit/groupby/GroupByFrameCombiner.java | 281 +++++++++++++
.../groupby/GroupByPostShuffleStageProcessor.java | 20 +-
.../groupby/GroupByPreShuffleStageProcessor.java | 22 ++
.../results/ExportResultsStageProcessor.java | 8 +-
.../results/QueryResultStageProcessor.java | 6 +-
.../druid/msq/util/MultiStageQueryContext.java | 14 +
.../org/apache/druid/msq/exec/MSQInsertTest.java | 3 +-
.../org/apache/druid/msq/exec/MSQSelectTest.java | 4 +-
.../druid/msq/querykit/FrameProcessorTestBase.java | 7 +-
.../common/SortMergeJoinFrameProcessorTest.java | 8 +-
.../results/QueryResultsFrameProcessorTest.java | 7 +-
.../querykit/scan/ScanQueryFrameProcessorTest.java | 6 +-
.../org/apache/druid/msq/test/MSQTestBase.java | 7 +
.../druid/frame/processor/FrameChannelMerger.java | 439 +++++++++++++++++----
.../druid/frame/processor/FrameCombiner.java | 60 +++
.../frame/processor/FrameCombinerFactory.java | 30 ++
.../apache/druid/frame/processor/SuperSorter.java | 7 +-
.../apache/druid/frame/read/FrameReaderUtils.java | 17 +-
.../frame/processor/SummingFrameCombiner.java | 165 ++++++++
.../druid/frame/processor/SuperSorterTest.java | 390 +++++++++++++++++-
32 files changed, 1517 insertions(+), 173 deletions(-)
diff --git
a/benchmarks/src/test/java/org/apache/druid/benchmark/frame/FrameChannelMergerBenchmark.java
b/benchmarks/src/test/java/org/apache/druid/benchmark/frame/FrameChannelMergerBenchmark.java
index b24dc2f2a66..a77859d6ec3 100644
---
a/benchmarks/src/test/java/org/apache/druid/benchmark/frame/FrameChannelMergerBenchmark.java
+++
b/benchmarks/src/test/java/org/apache/druid/benchmark/frame/FrameChannelMergerBenchmark.java
@@ -359,6 +359,7 @@ public class FrameChannelMergerBenchmark
),
sortKey,
null,
+ null,
-1
);
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardPartitionReader.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardPartitionReader.java
index 87f1fcc143e..895d31e6527 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardPartitionReader.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardPartitionReader.java
@@ -28,6 +28,7 @@ import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.processor.FrameChannelMerger;
import org.apache.druid.frame.processor.FrameChannelMixer;
+import org.apache.druid.frame.processor.FrameCombinerFactory;
import org.apache.druid.frame.processor.FrameProcessorExecutor;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.write.FrameWriters;
@@ -63,6 +64,8 @@ public class StandardPartitionReader
@Nullable
private final CounterTracker counters;
private final MemoryAllocatorFactory allocatorFactory;
+ @Nullable
+ private FrameCombinerFactory combinerFactory;
public StandardPartitionReader(ExecutionContext executionContext)
{
@@ -95,6 +98,15 @@ public class StandardPartitionReader
this.allocatorFactory = allocatorFactory;
}
+ /**
+ * Set a combiner for sorted merges.
+ */
+ public StandardPartitionReader setCombiner(@Nullable final
FrameCombinerFactory combinerFactory)
+ {
+ this.combinerFactory = combinerFactory;
+ return this;
+ }
+
public ReadableFrameChannel openChannel(final ReadablePartition
readablePartition) throws IOException
{
final StageDefinition stageDef =
queryDef.getStageDefinition(readablePartition.getStageNumber());
@@ -142,6 +154,7 @@ public class StandardPartitionReader
frameWriterSpec.getRemoveNullBytes()
),
stageDefinition.getSortKey(),
+ combinerFactory != null ? combinerFactory.newCombiner() : null,
null,
-1
);
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardShuffleOperations.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardShuffleOperations.java
index 10cbdabfd66..59618401406 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardShuffleOperations.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardShuffleOperations.java
@@ -33,6 +33,7 @@ import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.processor.Bouncer;
import org.apache.druid.frame.processor.FrameChannelHashPartitioner;
import org.apache.druid.frame.processor.FrameChannelMixer;
+import org.apache.druid.frame.processor.FrameCombinerFactory;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.FrameProcessorDecorator;
import org.apache.druid.frame.processor.OutputChannel;
@@ -59,6 +60,7 @@ import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
+import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
@@ -73,11 +75,17 @@ public class StandardShuffleOperations
{
private final ExecutionContext executionContext;
private final WorkOrder workOrder;
+ @Nullable
+ private final FrameCombinerFactory combinerFactory;
- public StandardShuffleOperations(final ExecutionContext executionContext)
+ public StandardShuffleOperations(
+ final ExecutionContext executionContext,
+ @Nullable final FrameCombinerFactory combinerFactory
+ )
{
this.executionContext = executionContext;
this.workOrder = executionContext.workOrder();
+ this.combinerFactory = combinerFactory;
}
/**
@@ -224,7 +232,8 @@ public class StandardShuffleOperations
stageDefinition.getShuffleSpec().limitHint(),
executionContext.cancellationId(),
executionContext.counters().sortProgress(),
-
executionContext.frameContext().frameWriterSpec().getRemoveNullBytes()
+
executionContext.frameContext().frameWriterSpec().getRemoveNullBytes(),
+ combinerFactory
);
return FutureUtils.transform(
@@ -402,7 +411,8 @@ public class StandardShuffleOperations
// There's a single SuperSorterProgressTrackerCounter
per worker, but workers that do local
// sorting have a SuperSorter per partition.
new SuperSorterProgressTracker(),
-
executionContext.frameContext().frameWriterSpec().getRemoveNullBytes()
+
executionContext.frameContext().frameWriterSpec().getRemoveNullBytes(),
+ combinerFactory
);
return FutureUtils.transform(sorter.run(), r ->
Iterables.getOnlyElement(r.getAllChannels()));
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardStageRunner.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardStageRunner.java
index 7216a1bff82..4f2024bd8df 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardStageRunner.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/exec/std/StandardStageRunner.java
@@ -24,6 +24,7 @@ import com.google.common.util.concurrent.ListenableFuture;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
import org.apache.druid.frame.processor.BlockingQueueOutputChannelFactory;
+import org.apache.druid.frame.processor.FrameCombinerFactory;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.manager.ProcessorManager;
import org.apache.druid.java.util.common.UOE;
@@ -37,6 +38,8 @@ import org.apache.druid.msq.kernel.ShuffleSpec;
import org.apache.druid.msq.kernel.StageDefinition;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import javax.annotation.Nullable;
+
/**
* Runner for {@link StageProcessor} that build a {@link
ProcessorsAndChannels} for some shuffle-agnostic work.
* The shuffle-related work is then taken care of generically in a "standard"
way, hence the name.
@@ -51,6 +54,9 @@ public class StandardStageRunner<T, R>
private final int threadCount;
private final FrameContext frameContext;
+ @Nullable
+ private FrameCombinerFactory combinerFactory;
+
@MonotonicNonNull
private OutputChannelFactory workOutputChannelFactory;
@MonotonicNonNull
@@ -65,6 +71,15 @@ public class StandardStageRunner<T, R>
this.frameContext = executionContext.frameContext();
}
+ /**
+ * Set a combiner for sorted shuffle.
+ */
+ public StandardStageRunner<T, R> setCombiner(@Nullable final
FrameCombinerFactory combinerFactory)
+ {
+ this.combinerFactory = combinerFactory;
+ return this;
+ }
+
/**
* Start execution.
*
@@ -162,7 +177,7 @@ public class StandardStageRunner<T, R>
private void makeAndRunShuffleProcessors()
{
final ShuffleSpec shuffleSpec =
executionContext.workOrder().getStageDefinition().getShuffleSpec();
- final StandardShuffleOperations stageOperations = new
StandardShuffleOperations(executionContext);
+ final StandardShuffleOperations stageOperations = new
StandardShuffleOperations(executionContext, combinerFactory);
pipelineFuture =
stageOperations.gatherResultKeyStatisticsIfNeeded(pipelineFuture);
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/processor/SegmentGeneratorStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/processor/SegmentGeneratorStageProcessor.java
index c80c03268fc..9699020f92d 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/processor/SegmentGeneratorStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/processor/SegmentGeneratorStageProcessor.java
@@ -45,10 +45,10 @@ import org.apache.druid.msq.exec.FrameContext;
import org.apache.druid.msq.exec.StageProcessor;
import org.apache.druid.msq.exec.WorkerMemoryParameters;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.input.stage.StageInputSlice;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.querykit.ReadableInput;
import org.apache.druid.segment.IndexSpec;
@@ -157,7 +157,7 @@ public class SegmentGeneratorStageProcessor implements
StageProcessor<Set<DataSe
// Expect a single input slice.
final StageInputSlice slice = (StageInputSlice)
Iterables.getOnlyElement(context.workOrder().getInputs());
final Sequence<Pair<Integer, ReadableInput>> inputSequence =
- QueryKitUtils.readPartitions(context, slice.getPartitions())
+ QueryKitUtils.readPartitions(new StandardPartitionReader(context),
slice.getPartitions())
.map(
new Function<>()
{
@@ -177,9 +177,15 @@ public class SegmentGeneratorStageProcessor implements
StageProcessor<Set<DataSe
final Sequence<SegmentGeneratorFrameProcessor> workers = inputSequence.map(
readableInputPair -> {
- final StagePartition stagePartition =
Preconditions.checkNotNull(readableInputPair.rhs.getStagePartition());
+ final ReadableInput readableInput = readableInputPair.rhs;
final SegmentIdWithShardSpec segmentIdWithShardSpec =
extra.get(readableInputPair.lhs);
- final String idString = StringUtils.format("%s:%s", stagePartition,
context.workOrder().getWorkerNumber());
+ final String idString = StringUtils.format(
+ "%s_%s_%s:%s",
+ context.workOrder().getStageDefinition().getId().getQueryId(),
+ readableInput.getStageNumber(),
+ readableInput.getPartitionNumber(),
+ context.workOrder().getWorkerNumber()
+ );
final File persistDirectory = new File(
frameContext.persistDir(),
segmentIdWithShardSpec.asSegmentId().toString()
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java
index 357829cdd29..1748b528113 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BaseLeafStageProcessor.java
@@ -86,6 +86,7 @@ public abstract class BaseLeafStageProcessor extends
BasicStageProcessor
public ListenableFuture<Long> execute(ExecutionContext context)
{
final StandardStageRunner<Object, Long> stageRunner = new
StandardStageRunner<>(context);
+ configureStageRunner(stageRunner, context);
final List<InputSlice> inputSlices = context.workOrder().getInputs();
final StageDefinition stageDefinition =
context.workOrder().getStageDefinition();
final FrameContext frameContext = context.frameContext();
@@ -228,6 +229,17 @@ public abstract class BaseLeafStageProcessor extends
BasicStageProcessor
);
}
+ /**
+ * Hook for subclasses to configure the stage runner before execution.
+ */
+ protected void configureStageRunner(
+ final StandardStageRunner<Object, Long> stageRunner,
+ final ExecutionContext context
+ )
+ {
+ // Default: no-op
+ }
+
protected abstract FrameProcessor<Object> makeProcessor(
ReadableInput baseInput,
SegmentMapFunction segmentMapFn,
@@ -278,7 +290,6 @@ public abstract class BaseLeafStageProcessor extends
BasicStageProcessor
final Integer segmentLoadAheadCount =
MultiStageQueryContext.getSegmentLoadAheadCount(context.workOrder().getWorkerContext());
return new ReadableInputQueue(
- stageDef.getId().getQueryId(),
new StandardPartitionReader(context),
filteredSlices,
segmentLoadAheadCount != null ? segmentLoadAheadCount :
context.threadCount()
@@ -319,7 +330,7 @@ public abstract class BaseLeafStageProcessor extends
BasicStageProcessor
)
);
final FrameReader frameReader =
partitionReader.frameReader(slice.getStageNumber());
- broadcastInputs.put(inputNumber, ReadableInput.channel(channel,
frameReader, null));
+ broadcastInputs.put(inputNumber, ReadableInput.channel(channel,
frameReader, slice.getStageNumber(), ReadableInput.NO_PARTITION));
}
}
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
index 478a7cfb554..4ba1ba8ab6a 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/QueryKitUtils.java
@@ -34,14 +34,11 @@ import
org.apache.druid.java.util.common.granularity.PeriodGranularity;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.math.expr.ExprMacroTable;
-import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.indexing.error.ColumnNameRestrictedFault;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.input.stage.ReadablePartition;
import org.apache.druid.msq.input.stage.ReadablePartitions;
-import org.apache.druid.msq.kernel.StageId;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.expression.TimestampFloorExprMacro;
import org.apache.druid.segment.VirtualColumn;
@@ -237,23 +234,19 @@ public class QueryKitUtils
}
/**
- * Create a sequence of {@link ReadableInput} corresponding to {@link
ReadablePartitions}, read with a standard merger.
+ * Create a {@link ReadableInput} for a single {@link ReadablePartition},
read with the given partition reader.
*/
public static ReadableInput readPartition(
- final ExecutionContext context,
+ final StandardPartitionReader partitionReader,
final ReadablePartition readablePartition
)
{
- final StandardPartitionReader partitionReader = new
StandardPartitionReader(context);
- final String queryId =
context.workOrder().getStageDefinition().getId().getQueryId();
try {
return ReadableInput.channel(
partitionReader.openChannel(readablePartition),
partitionReader.frameReader(readablePartition.getStageNumber()),
- new StagePartition(
- new StageId(queryId, readablePartition.getStageNumber()),
- readablePartition.getPartitionNumber()
- )
+ readablePartition.getStageNumber(),
+ readablePartition.getPartitionNumber()
);
}
catch (IOException e) {
@@ -262,13 +255,14 @@ public class QueryKitUtils
}
/**
- * Create a sequence of {@link ReadableInput} corresponding to {@link
ReadablePartitions}, read with a standard merger.
+ * Create a sequence of {@link ReadableInput} corresponding to {@link
ReadablePartitions}, read with the given
+ * partition reader.
*/
public static Sequence<ReadableInput> readPartitions(
- final ExecutionContext context,
+ final StandardPartitionReader partitionReader,
final ReadablePartitions readablePartitions
)
{
- return Sequences.simple(readablePartitions).map(partition ->
readPartition(context, partition));
+ return Sequences.simple(readablePartitions).map(partition ->
readPartition(partitionReader, partition));
}
}
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInput.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInput.java
index 3095c13aa78..348ce224462 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInput.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInput.java
@@ -26,7 +26,6 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.exec.DataServerQueryHandler;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSliceReader;
-import org.apache.druid.msq.kernel.StagePartition;
import javax.annotation.Nullable;
@@ -38,6 +37,16 @@ import javax.annotation.Nullable;
*/
public class ReadableInput
{
+ /**
+ * Constant indicating that no stage number is associated with this input.
+ */
+ public static final int NO_STAGE = -1;
+
+ /**
+ * Constant indicating that no partition number is associated with this
input.
+ */
+ public static final int NO_PARTITION = -1;
+
@Nullable
private final SegmentReferenceHolder segment;
@@ -50,22 +59,24 @@ public class ReadableInput
@Nullable
private final FrameReader frameReader;
- @Nullable
- private final StagePartition stagePartition;
+ private final int stageNumber;
+ private final int partitionNumber;
private ReadableInput(
@Nullable SegmentReferenceHolder segment,
@Nullable DataServerQueryHandler dataServerQuery,
@Nullable ReadableFrameChannel channel,
@Nullable FrameReader frameReader,
- @Nullable StagePartition stagePartition
+ final int stageNumber,
+ final int partitionNumber
)
{
this.segment = segment;
this.dataServerQuery = dataServerQuery;
this.channel = channel;
this.frameReader = frameReader;
- this.stagePartition = stagePartition;
+ this.stageNumber = stageNumber;
+ this.partitionNumber = partitionNumber;
if ((segment == null) && (channel == null) && (dataServerQuery == null)) {
throw new ISE("Provide 'segment', 'dataServerQuery' or 'channel'");
@@ -79,7 +90,7 @@ public class ReadableInput
*/
public static ReadableInput segment(final SegmentReferenceHolder segment)
{
- return new ReadableInput(Preconditions.checkNotNull(segment, "segment"),
null, null, null, null);
+ return new ReadableInput(Preconditions.checkNotNull(segment, "segment"),
null, null, null, NO_STAGE, NO_PARTITION);
}
/**
@@ -89,21 +100,22 @@ public class ReadableInput
*/
public static ReadableInput dataServerQuery(final DataServerQueryHandler
dataServerQueryHandler)
{
- return new ReadableInput(null,
Preconditions.checkNotNull(dataServerQueryHandler, "dataServerQuery"), null,
null, null);
+ return new ReadableInput(null,
Preconditions.checkNotNull(dataServerQueryHandler, "dataServerQuery"), null,
null, NO_STAGE, NO_PARTITION);
}
/**
* Create an input associated with a channel.
*
- * @param channel the channel
- * @param frameReader reader for the channel
- * @param stagePartition stage-partition associated with the channel, if
meaningful. May be null if this channel
- * does not correspond to any one particular
stage-partition.
+ * @param channel the channel
+ * @param frameReader reader for the channel
+ * @param stageNumber stage number associated with the channel
+ * @param partitionNumber partition number associated with the channel
*/
public static ReadableInput channel(
final ReadableFrameChannel channel,
final FrameReader frameReader,
- @Nullable final StagePartition stagePartition
+ final int stageNumber,
+ final int partitionNumber
)
{
return new ReadableInput(
@@ -111,7 +123,8 @@ public class ReadableInput
null,
Preconditions.checkNotNull(channel, "channel"),
Preconditions.checkNotNull(frameReader, "frameReader"),
- stagePartition
+ stageNumber,
+ partitionNumber
);
}
@@ -132,7 +145,7 @@ public class ReadableInput
}
/**
- * Whether this input is a channel (from {@link
#channel(ReadableFrameChannel, FrameReader, StagePartition)}.
+ * Whether this input is a channel (from {@link
#channel(ReadableFrameChannel, FrameReader, int, int)}.
*/
public boolean hasChannel()
{
@@ -176,18 +189,21 @@ public class ReadableInput
}
/**
- * The stage-partition this input. Only valid if {@link #hasChannel()}, and
if a stage-partition was provided
- * during construction. Throws {@link IllegalStateException} if no
stage-partition was provided during construction.
+ * The stage number for this input. Only valid if {@link #hasChannel()}.
*/
- public StagePartition getStagePartition()
+ public int getStageNumber()
{
checkIsChannel();
+ return stageNumber;
+ }
- if (stagePartition == null) {
- throw new ISE("Stage-partition is not set for this channel");
- }
-
- return stagePartition;
+ /**
+ * The partition number for this input. Only valid if {@link #hasChannel()}.
+ */
+ public int getPartitionNumber()
+ {
+ checkIsChannel();
+ return partitionNumber;
}
private void checkIsSegment()
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInputQueue.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInputQueue.java
index 4b3ed148596..b965e8babc2 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInputQueue.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ReadableInputQueue.java
@@ -33,8 +33,6 @@ import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.input.LoadableSegment;
import org.apache.druid.msq.input.PhysicalInputSlice;
import org.apache.druid.msq.input.stage.ReadablePartition;
-import org.apache.druid.msq.kernel.StageId;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.SegmentReference;
import org.apache.druid.segment.loading.AcquireSegmentAction;
@@ -99,19 +97,16 @@ public class ReadableInputQueue implements Closeable
@GuardedBy("this")
private final Set<ListenableFuture<ReadableInput>> pendingNextInputs =
Sets.newIdentityHashSet();
- private final String queryId;
private final StandardPartitionReader partitionReader;
private final int loadahead;
private final AtomicBoolean started = new AtomicBoolean(false);
public ReadableInputQueue(
- final String queryId,
final StandardPartitionReader partitionReader,
final List<PhysicalInputSlice> slices,
final int loadahead
)
{
- this.queryId = queryId;
this.partitionReader = partitionReader;
this.loadahead = loadahead;
@@ -242,10 +237,8 @@ public class ReadableInputQueue implements Closeable
ReadableInput.channel(
channel,
partitionReader.frameReader(readablePartition.getStageNumber()),
- new StagePartition(
- new StageId(queryId, readablePartition.getStageNumber()),
- readablePartition.getPartitionNumber()
- )
+ readablePartition.getStageNumber(),
+ readablePartition.getPartitionNumber()
)
);
}
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryStageProcessor.java
index 78eface70d6..5458fcb9a9b 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/WindowOperatorQueryStageProcessor.java
@@ -37,6 +37,7 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.std.BasicStageProcessor;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.stage.ReadablePartition;
@@ -133,11 +134,12 @@ public class WindowOperatorQueryStageProcessor extends
BasicStageProcessor
);
}
- final Sequence<ReadableInput> readableInputs =
QueryKitUtils.readPartitions(context, slice.getPartitions());
+ final Sequence<ReadableInput> readableInputs =
+ QueryKitUtils.readPartitions(new StandardPartitionReader(context),
slice.getPartitions());
final Sequence<FrameProcessor<Object>> processors = readableInputs.map(
readableInput -> {
final OutputChannel outputChannel =
-
outputChannels.get(readableInput.getStagePartition().getPartitionNumber());
+ outputChannels.get(readableInput.getPartitionNumber());
return new WindowOperatorQueryFrameProcessor(
query.context(),
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitStageProcessor.java
index 93458f614eb..dcf23ce0199 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/OffsetLimitStageProcessor.java
@@ -37,6 +37,7 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.std.BasicStageProcessor;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.querykit.QueryKitUtils;
@@ -113,10 +114,11 @@ public class OffsetLimitStageProcessor extends
BasicStageProcessor
throw new RuntimeException(e);
}
+ final StandardPartitionReader partitionReader = new
StandardPartitionReader(context);
final Supplier<FrameProcessor<Object>> workerSupplier = () -> {
final Iterable<ReadableInput> readableInputs = Iterables.transform(
slice.getPartitions(),
- readablePartition -> QueryKitUtils.readPartition(context,
readablePartition)
+ readablePartition -> QueryKitUtils.readPartition(partitionReader,
readablePartition)
);
// Note: OffsetLimitFrameProcessor does not use allocator from the
outputChannel; it uses unlimited instead.
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinStageProcessor.java
index 7b5502a8e07..d80a8f48882 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinStageProcessor.java
@@ -48,6 +48,7 @@ import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.std.BasicStageProcessor;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSliceReader;
@@ -55,7 +56,6 @@ import org.apache.druid.msq.input.NilInputSlice;
import org.apache.druid.msq.input.stage.ReadablePartition;
import org.apache.druid.msq.input.stage.ReadablePartitions;
import org.apache.druid.msq.input.stage.StageInputSlice;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.querykit.ReadableInput;
import org.apache.druid.segment.column.RowSignature;
@@ -321,7 +321,7 @@ public class SortMergeJoinStageProcessor extends
BasicStageProcessor
* first provided {@link InputSlice}.
*
* "Missing" partitions -- which occur when one slice has no data for a
given partition -- are replaced with
- * {@link ReadableInput} based on {@link ReadableNilFrameChannel}, with no
{@link StagePartition}.
+ * {@link ReadableInput} based on {@link ReadableNilFrameChannel}.
*
* @throws IllegalStateException if any slices are not {@link
StageInputSlice} or {@link NilInputSlice}
*/
@@ -334,6 +334,7 @@ public class SortMergeJoinStageProcessor extends
BasicStageProcessor
// Partition number -> Input number -> Input channel
final Int2ObjectMap<List<ReadableInput>> retVal = new
Int2ObjectRBTreeMap<>();
+ final StandardPartitionReader partitionReader = new
StandardPartitionReader(context);
for (int inputNumber = 0; inputNumber < slices.size(); inputNumber++) {
final InputSlice slice = slices.get(inputNumber);
@@ -352,7 +353,7 @@ public class SortMergeJoinStageProcessor extends
BasicStageProcessor
final ReadablePartitions partitions = ((StageInputSlice)
slice).getPartitions();
for (final ReadablePartition partition : partitions) {
retVal.computeIfAbsent(partition.getPartitionNumber(), ignored ->
Arrays.asList(new ReadableInput[slices.size()]))
- .set(inputNumber, QueryKitUtils.readPartition(context,
partition));
+ .set(inputNumber, QueryKitUtils.readPartition(partitionReader,
partition));
}
} else if (!(slice instanceof NilInputSlice)) {
throw DruidException.defensive("Slice[%s] is not a 'stage' or 'nil'
slice", slice);
@@ -368,7 +369,8 @@ public class SortMergeJoinStageProcessor extends
BasicStageProcessor
ReadableInput.channel(
ReadableNilFrameChannel.INSTANCE,
frameReadersByInputNumber.get(inputNumber),
- null
+ ReadableInput.NO_STAGE,
+ ReadableInput.NO_PARTITION
)
);
}
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
new file mode 100644
index 00000000000..6f51b7569f2
--- /dev/null
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByFrameCombiner.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.querykit.groupby;
+
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.processor.FrameCombiner;
+import org.apache.druid.frame.processor.FrameProcessors;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.segment.FrameCursor;
+import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
+import org.apache.druid.segment.BaseSingleValueDimensionSelector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.DimensionSelector;
+import org.apache.druid.segment.NilColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.RowSignature;
+
+import javax.annotation.Nullable;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
/**
 * Implementation of {@link FrameCombiner} for groupBy queries. Combines aggregate values for rows with
 * identical dimension keys using {@link AggregatorFactory#combine}.
 *
 * Stateful: {@link #init} supplies the {@link FrameReader}; {@link #reset} starts a new group from a
 * single row; {@link #combine} folds additional rows into {@link #aggregateValues}. The combined row is
 * exposed through {@link #getCombinedColumnSelectorFactory()}. Not thread-safe: all caches are plain,
 * unsynchronized mutable fields.
 */
public class GroupByFrameCombiner implements FrameCombiner
{
  // Signature of the input frames. Aggregate columns occupy indexes
  // [aggregatorStart, aggregatorStart + aggregatorFactories.size()); columns before that are key columns.
  private final RowSignature signature;

  // Index within "signature" of the first aggregate column.
  private final int aggregatorStart;

  // Running combined value per aggregator for the current group. Parallel to "aggregatorFactories".
  private final Object[] aggregateValues;
  private final List<AggregatorFactory> aggregatorFactories;

  // Serves key columns from the cached cursor and aggregate columns from "aggregateValues".
  private final CombinedColumnSelectorFactory combinedColumnSelectorFactory;

  // Set by init(FrameReader); must be initialized before reset/combine are called.
  private FrameReader frameReader;

  /**
   * Frame for {@link #cachedCursor} and {@link #cachedAggregatorSelectors}. Updated by {@link #getCursor(Frame)}.
   */
  @Nullable
  private Frame cachedFrame;

  /**
   * Cached cursor for the current frame. Updated by {@link #getCursor(Frame)}.
   */
  @Nullable
  private FrameCursor cachedCursor;

  /**
   * Cached aggregator selectors for the current frame. Updated by {@link #getCursor(Frame)}.
   * Same length as {@link #aggregatorFactories}.
   */
  @Nullable
  private ColumnValueSelector<?>[] cachedAggregatorSelectors;

  /**
   * @param signature           signature of the frames to be combined
   * @param aggregatorFactories factories whose {@link AggregatorFactory#combine} merges intermediate values
   * @param aggregatorStart     index of the first aggregate column within "signature"
   */
  public GroupByFrameCombiner(
      final RowSignature signature,
      final List<AggregatorFactory> aggregatorFactories,
      final int aggregatorStart
  )
  {
    this.signature = signature;
    this.aggregatorStart = aggregatorStart;
    this.aggregatorFactories = aggregatorFactories;
    this.aggregateValues = new Object[aggregatorFactories.size()];
    this.combinedColumnSelectorFactory = new CombinedColumnSelectorFactory();
  }

  @Override
  public void init(final FrameReader frameReader)
  {
    this.frameReader = frameReader;
  }

  /**
   * Starts a new group: positions the cursor at "row" of "frame" and snapshots that row's aggregate values
   * into {@link #aggregateValues}.
   */
  @Override
  public void reset(final Frame frame, final int row)
  {
    final FrameCursor cursor = getCursor(frame);
    cursor.setCurrentRow(row);

    // Read aggregate values from this row using cached selectors.
    for (int i = 0; i < aggregatorFactories.size(); i++) {
      aggregateValues[i] = cachedAggregatorSelectors[i].getObject();
    }
  }

  /**
   * Folds "row" of "frame" into the current group. The caller is responsible for ensuring this row's sort
   * key matches the group started by the most recent {@link #reset}.
   */
  @Override
  public void combine(final Frame frame, final int row)
  {
    final FrameCursor cursor = getCursor(frame);
    cursor.setCurrentRow(row);

    // Read and combine aggregate values using cached selectors.
    // NOTE(review): relies on AggregatorFactory.combine tolerating null operands when a row's
    // intermediate aggregate is null — confirm for the factories used in MSQ groupBy.
    for (int i = 0; i < aggregatorFactories.size(); i++) {
      final Object newValue = cachedAggregatorSelectors[i].getObject();
      aggregateValues[i] = aggregatorFactories.get(i).combine(aggregateValues[i], newValue);
    }
  }

  @Override
  public ColumnSelectorFactory getCombinedColumnSelectorFactory()
  {
    return combinedColumnSelectorFactory;
  }

  /**
   * Returns a cursor for the given frame, reusing a cached cursor if the frame has not changed.
   * Also rebuilds {@link #cachedAggregatorSelectors} when the cursor changes.
   */
  private FrameCursor getCursor(final Frame frame)
  {
    // Reference comparison is deliberate: the cache is keyed on the Frame object itself.
    //noinspection ObjectEquality
    if (frame != cachedFrame) {
      cachedFrame = frame;
      cachedCursor = FrameProcessors.makeCursor(frame, frameReader);

      // Reset dimension selectors, they need to be recreated for the new cursor.
      // NOTE(review): selectors handed out before this point still wrap the old cursor; this assumes
      // consumers re-request (or lazily resolve) selectors after a frame switch — confirm against
      // MultiColumnSelectorFactory usage in FrameChannelMerger.
      combinedColumnSelectorFactory.resetSelectorCache();

      // Rebuild aggregator selectors for the new cursor.
      final ColumnSelectorFactory columnSelectorFactory = cachedCursor.getColumnSelectorFactory();
      cachedAggregatorSelectors = new ColumnValueSelector<?>[aggregatorFactories.size()];
      for (int i = 0; i < aggregatorFactories.size(); i++) {
        cachedAggregatorSelectors[i] =
            columnSelectorFactory.makeColumnValueSelector(signature.getColumnName(aggregatorStart + i));
      }
    }
    return cachedCursor;
  }

  /**
   * ColumnSelectorFactory that reads dimension columns from the cached frame cursor, and aggregate columns from
   * {@link #aggregateValues}. Key columns can be read from any row in the current group, since
   * all rows in a group share the same key. The cached cursor is always positioned at the most recent row passed to
   * {@link #reset} or {@link #combine}.
   */
  private class CombinedColumnSelectorFactory implements ColumnSelectorFactory
  {
    /**
     * Cached dimension value selectors from {@link #cachedCursor}.
     */
    private final Map<String, ColumnValueSelector<?>> valueDimensionSelectorCache = new HashMap<>();

    /**
     * Cached dimension string selectors from {@link #cachedCursor}.
     */
    private final Map<DimensionSpec, DimensionSelector> stringDimensionSelectorCache = new HashMap<>();

    @Override
    public DimensionSelector makeDimensionSelector(final DimensionSpec dimensionSpec)
    {
      final int columnIndex = signature.indexOf(dimensionSpec.getDimension());

      if (columnIndex < 0) {
        // Unknown column: constant-null selector (still honors any extractionFn on the spec).
        return DimensionSelector.constant(null, dimensionSpec.getExtractionFn());
      } else if (columnIndex >= aggregatorStart && columnIndex < aggregatorStart + aggregatorFactories.size()) {
        // Aggregate column: return combined value as a single-valued DimensionSelector.
        final int aggIndex = columnIndex - aggregatorStart;
        return new BaseSingleValueDimensionSelector()
        {
          @Nullable
          @Override
          protected String getValue()
          {
            final Object val = aggregateValues[aggIndex];
            return val == null ? null : String.valueOf(val);
          }

          @Override
          public void inspectRuntimeShape(final RuntimeShapeInspector inspector)
          {
            // Do nothing.
          }
        };
      } else {
        // Dimension: delegate to cached dimension selector.
        return stringDimensionSelectorCache.computeIfAbsent(
            dimensionSpec,
            spec -> cachedCursor.getColumnSelectorFactory().makeDimensionSelector(spec)
        );
      }
    }

    @Override
    public ColumnValueSelector<?> makeColumnValueSelector(final String columnName)
    {
      final int columnIndex = signature.indexOf(columnName);

      if (columnIndex < 0) {
        // Unknown column: nil selector.
        return NilColumnValueSelector.instance();
      } else if (columnIndex >= aggregatorStart && columnIndex < aggregatorStart + aggregatorFactories.size()) {
        // Aggregate column: return combined value as a ColumnValueSelector.
        final int aggIndex = columnIndex - aggregatorStart;
        return new ColumnValueSelector<>()
        {
          // Primitive getters cast to Number; presumably callers check isNull() before primitive
          // reads (standard Druid selector convention) — the casts assume a non-null numeric value.
          @Override
          public double getDouble()
          {
            return ((Number) aggregateValues[aggIndex]).doubleValue();
          }

          @Override
          public float getFloat()
          {
            return ((Number) aggregateValues[aggIndex]).floatValue();
          }

          @Override
          public long getLong()
          {
            return ((Number) aggregateValues[aggIndex]).longValue();
          }

          @Override
          public boolean isNull()
          {
            return aggregateValues[aggIndex] == null;
          }

          @Nullable
          @Override
          public Object getObject()
          {
            return aggregateValues[aggIndex];
          }

          @Override
          public Class<?> classOfObject()
          {
            return Object.class;
          }

          @Override
          public void inspectRuntimeShape(final RuntimeShapeInspector inspector)
          {
            // Do nothing.
          }
        };
      } else {
        // Dimension: delegate to cached dimension value selector.
        return valueDimensionSelectorCache.computeIfAbsent(
            columnName,
            name -> cachedCursor.getColumnSelectorFactory().makeColumnValueSelector(name)
        );
      }
    }

    // Clears delegate-selector caches; invoked by getCursor(Frame) when the cached cursor is replaced.
    private void resetSelectorCache()
    {
      valueDimensionSelectorCache.clear();
      stringDimensionSelectorCache.clear();
    }

    @Nullable
    @Override
    public ColumnCapabilities getColumnCapabilities(final String column)
    {
      return signature.getColumnCapabilities(column);
    }
  }
}
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleStageProcessor.java
index b41c1efb49c..f8cea2e4afa 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPostShuffleStageProcessor.java
@@ -36,14 +36,18 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.std.BasicStageProcessor;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.stage.ReadablePartition;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.msq.querykit.ReadableInput;
+import org.apache.druid.msq.util.MultiStageQueryContext;
+import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupingEngine;
+import org.apache.druid.segment.column.RowSignature;
import javax.annotation.Nullable;
import java.io.IOException;
@@ -96,11 +100,23 @@ public class GroupByPostShuffleStageProcessor extends
BasicStageProcessor
);
}
- final Sequence<ReadableInput> readableInputs =
QueryKitUtils.readPartitions(context, slice.getPartitions());
+ final StandardPartitionReader partitionReader = new
StandardPartitionReader(context);
+ if
(MultiStageQueryContext.isUseCombiner(context.workOrder().getWorkerContext())) {
+ final RowSignature inputSignature =
partitionReader.frameReader(slice.getStageNumber()).signature();
+ final List<AggregatorFactory> aggregatorFactories =
query.getAggregatorSpecs();
+ final int aggregatorStart = query.getResultRowAggregatorStart();
+ partitionReader.setCombiner(
+ () -> new GroupByFrameCombiner(inputSignature, aggregatorFactories,
aggregatorStart)
+ );
+ }
+
+ final Sequence<ReadableInput> readableInputs =
+ QueryKitUtils.readPartitions(partitionReader, slice.getPartitions());
+
final Sequence<FrameProcessor<Object>> processors = readableInputs.map(
readableInput -> {
final OutputChannel outputChannel =
-
outputChannels.get(readableInput.getStagePartition().getPartitionNumber());
+ outputChannels.get(readableInput.getPartitionNumber());
return new GroupByPostShuffleFrameProcessor(
query,
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleStageProcessor.java
index b7cb8ece55c..36eee4bef94 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByPreShuffleStageProcessor.java
@@ -28,15 +28,20 @@ import org.apache.druid.collections.ResourceHolder;
import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.write.FrameWriterFactory;
+import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.FrameContext;
+import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.LoadableSegment;
import org.apache.druid.msq.input.PhysicalInputSlice;
import org.apache.druid.msq.querykit.BaseLeafStageProcessor;
import org.apache.druid.msq.querykit.ReadableInput;
+import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryToolChestWarehouse;
+import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.groupby.GroupByQuery;
import org.apache.druid.query.groupby.GroupingEngine;
import org.apache.druid.segment.SegmentMapFunction;
+import org.apache.druid.segment.column.RowSignature;
import org.joda.time.Interval;
import javax.annotation.Nullable;
@@ -90,6 +95,23 @@ public class GroupByPreShuffleStageProcessor extends
BaseLeafStageProcessor
);
}
  /**
   * Installs a {@link GroupByFrameCombiner} on the stage runner when the "useCombiner" context flag
   * ({@link MultiStageQueryContext#isUseCombiner}) is set, so rows with identical sort keys can be
   * combined during the shuffle. The combiner is built from the stage definition's signature (the
   * intermediate row signature), the query's aggregator specs, and the query's result-row aggregator
   * start offset.
   */
  @Override
  protected void configureStageRunner(
      final StandardStageRunner<Object, Long> stageRunner,
      final ExecutionContext context
  )
  {
    if (MultiStageQueryContext.isUseCombiner(context.workOrder().getWorkerContext())) {
      final RowSignature intermediateSignature =
          context.workOrder().getStageDefinition().getSignature();
      final List<AggregatorFactory> aggregatorFactories = query.getAggregatorSpecs();
      final int aggregatorStart = query.getResultRowAggregatorStart();
      // Supplier, not instance: GroupByFrameCombiner is stateful, so each merger needs its own.
      stageRunner.setCombiner(
          () -> new GroupByFrameCombiner(intermediateSignature, aggregatorFactories, aggregatorStart)
      );
    }
  }
+
@Override
protected List<PhysicalInputSlice> filterBaseInput(final
List<PhysicalInputSlice> slices)
{
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/ExportResultsStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/ExportResultsStageProcessor.java
index a1fd161f629..cedf86431c9 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/ExportResultsStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/ExportResultsStageProcessor.java
@@ -39,6 +39,7 @@ import org.apache.druid.msq.exec.ExtraInfoHolder;
import org.apache.druid.msq.exec.ResultsContext;
import org.apache.druid.msq.exec.StageProcessor;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.stage.StageInputSlice;
@@ -141,7 +142,8 @@ public class ExportResultsStageProcessor implements
StageProcessor<List<String>,
}
final ChannelCounters channelCounter =
context.counters().channel(CounterNames.outputChannel());
- final Sequence<ReadableInput> readableInputs =
QueryKitUtils.readPartitions(context, slice.getPartitions());
+ final Sequence<ReadableInput> readableInputs =
+ QueryKitUtils.readPartitions(new StandardPartitionReader(context),
slice.getPartitions());
final Sequence<FrameProcessor<Object>> processors = readableInputs.map(
readableInput -> new ExportResultsFrameProcessor(
@@ -154,12 +156,12 @@ public class ExportResultsStageProcessor implements
StageProcessor<List<String>,
getExportFilePath(
queryId,
context.workOrder().getWorkerNumber(),
- readableInput.getStagePartition().getPartitionNumber(),
+ readableInput.getPartitionNumber(),
exportFormat
),
columnMappings,
resultsContext,
- readableInput.getStagePartition().getPartitionNumber()
+ readableInput.getPartitionNumber()
)
);
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/QueryResultStageProcessor.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/QueryResultStageProcessor.java
index f64be1f6bd8..4c51c1b92a9 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/QueryResultStageProcessor.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/results/QueryResultStageProcessor.java
@@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.std.BasicStageProcessor;
import org.apache.druid.msq.exec.std.ProcessorsAndChannels;
+import org.apache.druid.msq.exec.std.StandardPartitionReader;
import org.apache.druid.msq.exec.std.StandardStageRunner;
import org.apache.druid.msq.input.stage.ReadablePartition;
import org.apache.druid.msq.input.stage.StageInputSlice;
@@ -84,11 +85,12 @@ public class QueryResultStageProcessor extends
BasicStageProcessor
);
}
- final Sequence<ReadableInput> readableInputs =
QueryKitUtils.readPartitions(context, slice.getPartitions());
+ final Sequence<ReadableInput> readableInputs =
+ QueryKitUtils.readPartitions(new StandardPartitionReader(context),
slice.getPartitions());
final Sequence<FrameProcessor<Object>> processors = readableInputs.map(
readableInput -> {
final OutputChannel outputChannel =
-
outputChannels.get(readableInput.getStagePartition().getPartitionNumber());
+ outputChannels.get(readableInput.getPartitionNumber());
return new QueryResultsFrameProcessor(
readableInput.getChannel(),
diff --git
a/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
b/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
index c12ce2d30ec..df8fd3da310 100644
---
a/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
+++
b/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java
@@ -28,6 +28,7 @@ import com.opencsv.RFC4180ParserBuilder;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.error.InvalidInput;
import org.apache.druid.frame.FrameType;
+import org.apache.druid.frame.processor.FrameCombiner;
import org.apache.druid.indexing.common.TaskLockType;
import org.apache.druid.indexing.common.task.Tasks;
import org.apache.druid.java.util.common.DateTimes;
@@ -38,6 +39,7 @@ import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.exec.ExecutionContext;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.exec.SegmentSource;
+import org.apache.druid.msq.exec.StageProcessor;
import org.apache.druid.msq.exec.WorkerMemoryParameters;
import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
import org.apache.druid.msq.indexing.error.MSQWarnings;
@@ -168,6 +170,13 @@ public class MultiStageQueryContext
public static final String CTX_REMOVE_NULL_BYTES = "removeNullBytes";
public static final boolean DEFAULT_REMOVE_NULL_BYTES = false;
+ /**
+ * Hint to {@link StageProcessor} implementations about whether they should
attempt to use
+ * {@link FrameCombiner} when doing sort-based aggregations.
+ */
+ public static final String CTX_USE_COMBINER = "useCombiner";
+ public static final boolean DEFAULT_USE_COMBINER = false;
+
/**
* Used by {@link #getMaxRowsInMemory(QueryContext)}.
*/
@@ -458,6 +467,11 @@ public class MultiStageQueryContext
return queryContext.getBoolean(CTX_REMOVE_NULL_BYTES,
DEFAULT_REMOVE_NULL_BYTES);
}
  /**
   * Returns whether the {@link #CTX_USE_COMBINER} ("useCombiner") hint is enabled for this query,
   * signaling stage processors to combine rows with identical sort keys during sort-based shuffles.
   * Defaults to {@link #DEFAULT_USE_COMBINER} (false).
   */
  public static boolean isUseCombiner(final QueryContext queryContext)
  {
    return queryContext.getBoolean(CTX_USE_COMBINER, DEFAULT_USE_COMBINER);
  }
+
public static boolean isDartQuery(final QueryContext queryContext)
{
return queryContext.get(QueryContexts.CTX_DART_QUERY_ID) != null;
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
index 7cb3c57e4a3..5c83a850a3b 100644
---
a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
+++
b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
@@ -118,7 +118,8 @@ public class MSQInsertTest extends MSQTestBase
{DURABLE_STORAGE, DURABLE_STORAGE_MSQ_CONTEXT},
{FAULT_TOLERANCE, FAULT_TOLERANCE_MSQ_CONTEXT},
{PARALLEL_MERGE, PARALLEL_MERGE_MSQ_CONTEXT},
- {WITH_APPEND_LOCK, QUERY_CONTEXT_WITH_APPEND_LOCK}
+ {WITH_APPEND_LOCK, QUERY_CONTEXT_WITH_APPEND_LOCK},
+ {USE_COMBINER, USE_COMBINER_MSQ_CONTEXT}
};
return Arrays.asList(data);
}
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
index 320807e9653..509e93333a8 100644
---
a/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
+++
b/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
@@ -145,11 +145,13 @@ public class MSQSelectTest extends MSQTestBase
{PARALLEL_MERGE, PARALLEL_MERGE_MSQ_CONTEXT},
{QUERY_RESULTS_WITH_DURABLE_STORAGE,
QUERY_RESULTS_WITH_DURABLE_STORAGE_CONTEXT},
{QUERY_RESULTS_WITH_DEFAULT, QUERY_RESULTS_WITH_DEFAULT_CONTEXT},
- {SUPERUSER, SUPERUSER_MSQ_CONTEXT}
+ {SUPERUSER, SUPERUSER_MSQ_CONTEXT},
+ {USE_COMBINER, USE_COMBINER_MSQ_CONTEXT}
};
return Arrays.asList(data);
}
+
@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testCalculator(String contextName, Map<String, Object> context)
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/FrameProcessorTestBase.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/FrameProcessorTestBase.java
index 45e00579a08..e6175925123 100644
---
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/FrameProcessorTestBase.java
+++
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/FrameProcessorTestBase.java
@@ -31,8 +31,6 @@ import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.testutil.FrameSequenceBuilder;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.guava.Sequence;
-import org.apache.druid.msq.kernel.StageId;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.segment.CursorFactory;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.testing.InitializedNullHandlingTest;
@@ -45,7 +43,8 @@ import java.util.concurrent.TimeUnit;
public class FrameProcessorTestBase extends InitializedNullHandlingTest
{
- protected static final StagePartition STAGE_PARTITION = new
StagePartition(new StageId("q", 0), 0);
+ protected static final int TEST_STAGE_NUMBER = 0;
+ protected static final int TEST_PARTITION_NUMBER = 0;
private ListeningExecutorService innerExec;
protected FrameProcessorExecutor exec;
@@ -103,6 +102,6 @@ public class FrameProcessorTestBase extends
InitializedNullHandlingTest
);
channel.writable().close();
- return ReadableInput.channel(channel.readable(),
FrameReader.create(signature), STAGE_PARTITION);
+ return ReadableInput.channel(channel.readable(),
FrameReader.create(signature), TEST_STAGE_NUMBER, TEST_PARTITION_NUMBER);
}
}
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
index 85a0511e0f3..cde1ec12dfa 100644
---
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
+++
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
@@ -106,7 +106,8 @@ public class SortMergeJoinFrameProcessorTest extends
FrameProcessorTestBase
final ReadableInput factChannel = ReadableInput.channel(
ReadableNilFrameChannel.INSTANCE,
FrameReader.create(JoinTestHelper.FACT_SIGNATURE),
- STAGE_PARTITION
+ TEST_STAGE_NUMBER,
+ TEST_PARTITION_NUMBER
);
final ReadableInput countriesChannel =
@@ -155,7 +156,8 @@ public class SortMergeJoinFrameProcessorTest extends
FrameProcessorTestBase
final ReadableInput countriesChannel = ReadableInput.channel(
ReadableNilFrameChannel.INSTANCE,
FrameReader.create(JoinTestHelper.COUNTRIES_SIGNATURE),
- STAGE_PARTITION
+ TEST_STAGE_NUMBER,
+ TEST_PARTITION_NUMBER
);
final BlockingQueueFrameChannel outputChannel =
BlockingQueueFrameChannel.minimal();
@@ -232,7 +234,7 @@ public class SortMergeJoinFrameProcessorTest extends
FrameProcessorTestBase
final ReadableInput countriesChannel = ReadableInput.channel(
ReadableNilFrameChannel.INSTANCE,
FrameReader.create(JoinTestHelper.COUNTRIES_SIGNATURE),
- STAGE_PARTITION
+ TEST_STAGE_NUMBER, TEST_PARTITION_NUMBER
);
final BlockingQueueFrameChannel outputChannel =
BlockingQueueFrameChannel.minimal();
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/results/QueryResultsFrameProcessorTest.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/results/QueryResultsFrameProcessorTest.java
index 97a4f868e20..bb8374c3d86 100644
---
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/results/QueryResultsFrameProcessorTest.java
+++
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/results/QueryResultsFrameProcessorTest.java
@@ -30,8 +30,6 @@ import org.apache.druid.frame.testutil.FrameSequenceBuilder;
import org.apache.druid.frame.testutil.FrameTestUtil;
import org.apache.druid.java.util.common.Unit;
import org.apache.druid.java.util.common.guava.Sequence;
-import org.apache.druid.msq.kernel.StageId;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.msq.querykit.FrameProcessorTestBase;
import org.apache.druid.msq.querykit.ReadableInput;
import org.apache.druid.segment.TestIndex;
@@ -70,14 +68,13 @@ public class QueryResultsFrameProcessorTest extends
FrameProcessorTestBase
}
}
- final StagePartition stagePartition = new StagePartition(new
StageId("query", 0), 0);
-
final QueryResultsFrameProcessor processor =
new QueryResultsFrameProcessor(
ReadableInput.channel(
inputChannel.readable(),
FrameReader.create(signature),
- stagePartition
+ 0,
+ 0
).getChannel(),
outputChannel.writable()
);
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
index 7f07e27c166..8b5952adcaf 100644
---
a/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
+++
b/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/scan/ScanQueryFrameProcessorTest.java
@@ -39,8 +39,6 @@ import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.Unit;
import org.apache.druid.java.util.common.guava.Sequence;
-import org.apache.druid.msq.kernel.StageId;
-import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.msq.querykit.FrameProcessorTestBase;
import org.apache.druid.msq.querykit.ReadableInput;
import org.apache.druid.msq.querykit.SegmentReferenceHolder;
@@ -205,8 +203,6 @@ public class ScanQueryFrameProcessorTest extends
FrameProcessorTestBase
.columns(cursorFactory.getRowSignature().getColumnNames())
.build();
- final StagePartition stagePartition = new StagePartition(new
StageId("query", 0), 0);
-
// Limit output frames to 1 row to ensure we test edge cases
final FrameWriterFactory frameWriterFactory = new
LimitedFrameWriterFactory(
FrameWriters.makeFrameWriterFactory(
@@ -223,7 +219,7 @@ public class ScanQueryFrameProcessorTest extends
FrameProcessorTestBase
query,
null,
new DefaultObjectMapper(),
- ReadableInput.channel(inputChannel.readable(),
FrameReader.create(signature), stagePartition),
+ ReadableInput.channel(inputChannel.readable(),
FrameReader.create(signature), 0, 0),
SegmentMapFunction.IDENTITY,
new ResourceHolder<>()
{
diff --git
a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index 64379f711a0..5e50f1dcab6 100644
--- a/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ b/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -313,6 +313,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
)
.build();
+ public static final Map<String, Object> USE_COMBINER_MSQ_CONTEXT =
+ ImmutableMap.<String, Object>builder()
+ .putAll(DEFAULT_MSQ_CONTEXT)
+ .put(MultiStageQueryContext.CTX_USE_COMBINER, true)
+ .build();
+
public static final Map<String, Object>
FAIL_EMPTY_INSERT_ENABLED_MSQ_CONTEXT =
ImmutableMap.<String, Object>builder()
.putAll(DEFAULT_MSQ_CONTEXT)
@@ -333,6 +339,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
public static final String DEFAULT = "default";
public static final String PARALLEL_MERGE = "parallel_merge";
public static final String SUPERUSER = "superuser";
+ public static final String USE_COMBINER = "use_combiner";
protected File localFileStorageDir;
protected LocalFileStorageConnector localFileStorageConnector;
diff --git
a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
index da52cbf2096..8fe1682ecab 100644
---
a/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
+++
b/processing/src/main/java/org/apache/druid/frame/processor/FrameChannelMerger.java
@@ -52,6 +52,9 @@ import java.util.function.Supplier;
*
* Frames from input channels must be {@link FrameType#isRowBased()}. Output
frames will be row-based as well.
*
+ * Optionally supports combining adjacent rows with identical sort keys via a
{@link FrameCombiner}, which
+ * reduces intermediate data volume during merge-sort. All combine-specific
state is held in {@link CombineState}.
+ *
* For unsorted output, use {@link FrameChannelMixer} instead.
*/
public class FrameChannelMerger implements FrameProcessor<Long>
@@ -70,13 +73,19 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
private long rowsOutput = 0;
private int currentPartition = 0;
+ /**
+ * Non-null when combining is enabled. Holds the combiner and all
pending-row state.
+ */
+ @Nullable
+ private final CombineState combineState;
+
/**
* Channels that still have input to read.
*/
private final IntSet remainingChannels;
// ColumnSelectorFactory that always reads from the current row in the
merged sequence.
- final MultiColumnSelectorFactory mergedColumnSelectorFactory;
+ private final MultiColumnSelectorFactory mergedColumnSelectorFactory;
/**
* @param inputChannels readable frame channels. Each channel must be
sorted (i.e., if all frames in the channel
@@ -85,6 +94,7 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
* @param outputChannel writable channel to receive the merge-sorted
data
* @param frameWriterFactory writer for frames
* @param sortKey sort key for input and output frames
+ * @param combiner optional combiner for merging rows with
identical sort keys
* @param partitions partitions for output frames. If non-null,
output frames are written with
* partition numbers set according to this
parameter
* @param rowLimit maximum number of rows to write to the output
channel
@@ -95,6 +105,7 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
final WritableFrameChannel outputChannel,
final FrameWriterFactory frameWriterFactory,
final List<KeyColumn> sortKey,
+ @Nullable final FrameCombiner combiner,
@Nullable final ClusterByPartitions partitions,
final long rowLimit
)
@@ -148,13 +159,21 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
);
final List<Supplier<ColumnSelectorFactory>>
frameColumnSelectorFactorySuppliers =
- new ArrayList<>(inputChannels.size());
+ new ArrayList<>(inputChannels.size() + (combiner != null ? 1 : 0));
for (int i = 0; i < inputChannels.size(); i++) {
final int frameNumber = i;
frameColumnSelectorFactorySuppliers.add(() ->
currentFrames[frameNumber].cursor.getColumnSelectorFactory());
}
+ if (combiner != null) {
+ combiner.init(frameReader);
+
frameColumnSelectorFactorySuppliers.add(combiner::getCombinedColumnSelectorFactory);
+ this.combineState = new CombineState(combiner, inputChannels.size());
+ } else {
+ this.combineState = null;
+ }
+
this.mergedColumnSelectorFactory = new MultiColumnSelectorFactory(
frameColumnSelectorFactorySuppliers,
frameReader.signature()
@@ -182,16 +201,16 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
return ReturnOrAwait.awaitAll(awaitSet);
}
- // Check finished() after populateCurrentFramesAndTournamentTree().
- if (finished()) {
+ // After populateCurrentFramesAndTournamentTree(), check if we're done.
+ if (doneReadingInput() && !hasPendingCombineRow()) {
return ReturnOrAwait.returnObject(rowsOutput);
}
// Generate one output frame and stop for now.
writeNextFrame();
- // Check finished() after writeNextFrame().
- if (finished()) {
+ // After writeNextFrame(), check if we're done.
+ if (doneReadingInput() && !hasPendingCombineRow()) {
return ReturnOrAwait.returnObject(rowsOutput);
} else {
return ReturnOrAwait.runAgain();
@@ -200,7 +219,7 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
private void writeNextFrame() throws IOException
{
- if (finished()) {
+ if (doneReadingInput() && !hasPendingCombineRow()) {
throw new NoSuchElementException();
}
@@ -208,88 +227,245 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
int mergedFramePartition = currentPartition;
RowKey currentPartitionEnd = partitions.get(currentPartition).getEnd();
- while (!finished()) {
- final int currentChannel = tournamentTree.getMin();
- mergedColumnSelectorFactory.setCurrentFactory(currentChannel);
-
- if (currentPartitionEnd != null) {
- final FramePlus currentFrame = currentFrames[currentChannel];
- if (currentFrame.comparisonWidget.compare(currentFrame.rowNumber(),
currentPartitionEnd) >= 0) {
- // Current key is past the end of the partition. Advance
currentPartition til it matches the current key.
- do {
- currentPartition++;
- currentPartitionEnd = partitions.get(currentPartition).getEnd();
- } while (currentPartitionEnd != null
- &&
currentFrame.comparisonWidget.compare(currentFrame.rowNumber(),
currentPartitionEnd) >= 0);
-
- if (mergedFrameWriter.getNumRows() == 0) {
- // Fall through: keep reading into the new partition.
- mergedFramePartition = currentPartition;
- } else {
- // Write current frame.
- break;
- }
- }
- }
+ if (combineState != null) {
+ writeNextFrameWithCombiner(mergedFrameWriter, mergedFramePartition,
currentPartitionEnd);
+ } else {
+ writeNextFrameNoCombiner(mergedFrameWriter, mergedFramePartition,
currentPartitionEnd);
+ }
+ }
+ }
+
+ /**
+ * Merge logic without combining.
+ */
+ private void writeNextFrameNoCombiner(
+ final FrameWriter mergedFrameWriter,
+ int mergedFramePartition,
+ RowKey currentPartitionEnd
+ ) throws IOException
+ {
+ while (!doneReadingInput()) {
+ final int currentChannel = tournamentTree.getMin();
+ mergedColumnSelectorFactory.setCurrentFactory(currentChannel);
+
+ if (currentPartitionEnd != null) {
+ final FramePlus currentFrame = currentFrames[currentChannel];
+ if (currentFrame.comparisonWidget.compare(currentFrame.rowNumber(),
currentPartitionEnd) >= 0) {
+ // Current key is past the end of the partition. Advance
currentPartition til it matches the current key.
+ do {
+ currentPartition++;
+ currentPartitionEnd = partitions.get(currentPartition).getEnd();
+ } while (currentPartitionEnd != null
+ &&
currentFrame.comparisonWidget.compare(currentFrame.rowNumber(),
currentPartitionEnd) >= 0);
- if (mergedFrameWriter.addSelection()) {
- rowsOutput++;
- } else {
if (mergedFrameWriter.getNumRows() == 0) {
- throw new
FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+ // Fall through: keep reading into the new partition.
+ mergedFramePartition = currentPartition;
+ } else {
+ // Write current frame.
+ break;
}
+ }
+ }
- // Frame is full. Write the current frame.
- break;
+ if (mergedFrameWriter.addSelection()) {
+ rowsOutput++;
+ } else {
+ if (mergedFrameWriter.getNumRows() == 0) {
+ throw new
FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
}
- if (rowLimit != UNLIMITED && rowsOutput >= rowLimit) {
- // Limit reached; we're done.
- Arrays.fill(currentFrames, null);
- remainingChannels.clear();
- } else {
- // Continue reading the currentChannel.
- final FramePlus channelFramePlus = currentFrames[currentChannel];
- channelFramePlus.cursor.advance();
-
- if (channelFramePlus.isDone()) {
- // Done reading current frame from "channel".
- // Clear it and see if there is another one available for
immediate loading.
- currentFrames[currentChannel] = null;
-
- final ReadableFrameChannel channel =
inputChannels.get(currentChannel);
-
- if (channel.canRead()) {
- // Read next frame from this channel.
- final Frame frame = channel.readFrame();
- final FramePlus framePlus = makeFramePlus(frame, frameReader);
- if (framePlus.isDone()) {
- // Nothing to read in this frame. Not finished; we can't
continue.
- // Finish up the current frame and write it.
- break;
- } else {
- currentFrames[currentChannel] = framePlus;
- }
- } else if (channel.isFinished()) {
- // Done reading this channel. Fall through and continue with
other channels.
- remainingChannels.remove(currentChannel);
- } else {
- // Nothing available, not finished; we can't continue. Finish up
the current frame and write it.
- break;
- }
- }
+ // Frame is full. Write the current frame.
+ break;
+ }
+
+ if (rowLimit != UNLIMITED && rowsOutput >= rowLimit) {
+ // Limit reached; we're done.
+ Arrays.fill(currentFrames, null);
+ remainingChannels.clear();
+ } else if (!advanceChannel(currentChannel)) {
+ // Channel not ready; finish up the current frame.
+ break;
+ }
+ }
+
+ final Frame nextFrame = Frame.wrap(mergedFrameWriter.toByteArray());
+ outputChannel.write(nextFrame, mergedFramePartition);
+ }
+
+ /**
+ * Merge logic with combining: buffers a "pending" row and combines adjacent
rows with the same key.
+ *
+ * Each iteration of the outer loop completes at most one group and writes
it to the frame. The logic may
+ * break early if not enough input channels are ready to generate a group.
+ */
+ private void writeNextFrameWithCombiner(
+ final FrameWriter mergedFrameWriter,
+ int mergedFramePartition,
+ RowKey currentPartitionEnd
+ ) throws IOException
+ {
+ OUTER:
+ while (!doneReadingInput() || combineState.hasPendingRow()) {
+ // Step 1: Ensure there is a pending row.
+ if (!combineState.hasPendingRow()) {
+ final int currentChannel = tournamentTree.getMin();
+ final FramePlus currentFrame = currentFrames[currentChannel];
+ combineState.savePending(currentFrame, currentChannel);
+ if (!advanceChannel(currentChannel)) {
+ break; // Channel not ready; pending saved, will resume next call.
+ }
+ }
+
+ // Step 2: Fold in subsequent rows with the same sort key.
+ while (!doneReadingInput()) {
+ final int currentChannel = tournamentTree.getMin();
+ final FramePlus currentFrame = currentFrames[currentChannel];
+ if (combineState.comparePendingKey(currentFrame.comparisonWidget,
currentFrame.rowNumber()) != 0) {
+ break; // Different key; group is complete.
+ }
+ if (!combineState.pendingCombined) {
+ combineState.combiner.reset(combineState.pendingFrame,
combineState.pendingRow);
+ combineState.pendingCombined = true;
+ }
+ combineState.combiner.combine(currentFrame.frame,
currentFrame.rowNumber());
+ if (!advanceChannel(currentChannel)) {
+ break OUTER; // Channel not ready; group may be incomplete.
}
}
+ // Step 3: Check whether the pending row crosses a partition boundary.
+ if (currentPartitionEnd != null &&
combineState.comparePendingToPartitionEnd(currentPartitionEnd) >= 0) {
+ do {
+ currentPartition++;
+ currentPartitionEnd = partitions.get(currentPartition).getEnd();
+ } while (currentPartitionEnd != null &&
combineState.comparePendingToPartitionEnd(currentPartitionEnd) >= 0);
+
+ if (mergedFrameWriter.getNumRows() > 0) {
+ break; // Frame has rows from the previous partition; write it first.
+ }
+ mergedFramePartition = currentPartition;
+ }
+
+ // Step 4: Flush the completed group to the frame writer.
+ if (!flushPendingCombineRow(mergedFrameWriter)) {
+ break; // Frame is full.
+ }
+
+ // Step 5: Check the row limit.
+ if (rowLimit != UNLIMITED && rowsOutput >= rowLimit) {
+ Arrays.fill(currentFrames, null);
+ remainingChannels.clear();
+ break;
+ }
+ }
+
+ if (mergedFrameWriter.getNumRows() > 0) {
final Frame nextFrame = Frame.wrap(mergedFrameWriter.toByteArray());
outputChannel.write(nextFrame, mergedFramePartition);
}
}
+ /**
+ * Flush the pending row to the frame writer. Returns true if the row was
written or did not exist (and therefore
+ * did not need to be written). Returns false if the frame is full. Only
used when combining, i.e. when
+ * {@link #combineState} is nonnull.
+ */
+ private boolean flushPendingCombineRow(final FrameWriter mergedFrameWriter)
+ {
+ if (!combineState.hasPendingRow()) {
+ return true;
+ }
+
+ final boolean didAdd;
+
+ if (combineState.pendingCombined) {
+ // Combined row: write via combiner's ColumnSelectorFactory.
+ mergedColumnSelectorFactory.setCurrentFactory(combineState.combinerSlot);
+ didAdd = mergedFrameWriter.addSelection();
+ } else {
+ // Non-combined row: try writing using original frame to avoid needing
to use the combiner.
+ //noinspection ObjectEquality
+ if (currentFrames[combineState.pendingChannel] != null
+ && currentFrames[combineState.pendingChannel].frame ==
combineState.pendingFrame) {
+ // Frame is still live.
+ final int savedRow =
currentFrames[combineState.pendingChannel].cursor.getCurrentRow();
+
currentFrames[combineState.pendingChannel].cursor.setCurrentRow(combineState.pendingRow);
+
mergedColumnSelectorFactory.setCurrentFactory(combineState.pendingChannel);
+ didAdd = mergedFrameWriter.addSelection();
+
currentFrames[combineState.pendingChannel].cursor.setCurrentRow(savedRow);
+ } else {
+ // Frame is no longer live, perhaps the pending row was the last row
of the prior frame.
+ // Write using the combiner as a fallback.
+ combineState.combiner.reset(combineState.pendingFrame,
combineState.pendingRow);
+
mergedColumnSelectorFactory.setCurrentFactory(combineState.combinerSlot);
+ didAdd = mergedFrameWriter.addSelection();
+ }
+ }
+
+ if (didAdd) {
+ rowsOutput++;
+ combineState.clearPending();
+ return true;
+ } else {
+ if (mergedFrameWriter.getNumRows() == 0) {
+ throw new
FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+ }
+ return false;
+ }
+ }
+
+ /**
+ * Whether there is a pending row waiting to be flushed (only possible when
combining).
+ */
+ private boolean hasPendingCombineRow()
+ {
+ return combineState != null && combineState.hasPendingRow();
+ }
+
+ /**
+ * Advance the cursor for a channel, loading a new frame if necessary.
Returns true if the channel was
+ * successfully advanced (more rows available, channel finished, or next
frame loaded). Returns false if
+ * the channel is not ready to read and not finished, meaning the caller
should finish up the current
+ * output frame and wait for more data.
+ */
+ private boolean advanceChannel(final int currentChannel)
+ {
+ final FramePlus channelFramePlus = currentFrames[currentChannel];
+ channelFramePlus.cursor.advance();
+
+ if (channelFramePlus.isDone()) {
+ // Done reading current frame from "channel".
+ // Clear it and see if there is another one available for immediate
loading.
+ currentFrames[currentChannel] = null;
+
+ final ReadableFrameChannel channel = inputChannels.get(currentChannel);
+
+ if (channel.canRead()) {
+ // Read next frame from this channel.
+ final Frame frame = channel.readFrame();
+ final FramePlus framePlus = makeFramePlus(frame, frameReader);
+ if (framePlus.isDone()) {
+ // Nothing to read in this frame, can't continue.
+ return false;
+ }
+ currentFrames[currentChannel] = framePlus;
+ } else if (channel.isFinished()) {
+ // Done reading this channel.
+ remainingChannels.remove(currentChannel);
+ } else {
+ // Nothing available, not finished; we can't continue. Caller should
finish up the current frame.
+ return false;
+ }
+ }
+
+ return true;
+ }
+
/**
* Returns whether all input is done being read.
*/
- private boolean finished()
+ private boolean doneReadingInput()
{
return remainingChannels.isEmpty();
}
@@ -352,7 +528,7 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
endRow = findRow(frame, comparisonWidget, endRowKey);
}
- return new FramePlus(cursor, comparisonWidget, endRow);
+ return new FramePlus(frame, cursor, comparisonWidget, endRow);
}
/**
@@ -391,16 +567,19 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
*/
private static class FramePlus
{
+ private final Frame frame;
private final FrameCursor cursor;
private final FrameComparisonWidget comparisonWidget;
private final int endRow;
public FramePlus(
+ final Frame frame,
final FrameCursor cursor,
final FrameComparisonWidget comparisonWidget,
final int endRow
)
{
+ this.frame = frame;
this.cursor = cursor;
this.comparisonWidget = comparisonWidget;
this.endRow = endRow;
@@ -416,4 +595,118 @@ public class FrameChannelMerger implements
FrameProcessor<Long>
return cursor.getCurrentRow() >= endRow;
}
}
+
+ /**
+ * Holds all state related to combine-during-merge. This includes the {@link
FrameCombiner} itself,
+ * the index of the combiner's ColumnSelectorFactory slot in the {@link
MultiColumnSelectorFactory}, and the
+ * "pending row" that is buffered while we check whether the next row has
the same key.
+ */
+ private static class CombineState
+ {
+ /**
+ * The combiner for this run.
+ */
+ private final FrameCombiner combiner;
+
+ /**
+ * Index of the combiner's {@link ColumnSelectorFactory} within the {@link
MultiColumnSelectorFactory}.
+ * Equal to the number of input channels (i.e., one past the last channel
slot).
+ */
+ private final int combinerSlot;
+
+ /**
+ * Whether there is a pending row buffered.
+ */
+ private boolean hasPendingRow;
+
+ /**
+ * Channel index that the pending row came from. Used by {@code
flushPendingCombineRow()} to attempt the fast path:
+ * if the channel's current {@link FramePlus} still references {@link
#pendingFrame}, we can write the row
+ * directly from the channel's cursor without involving the combiner. Only
valid if {@link #hasPendingRow}.
+ */
+ private int pendingChannel;
+
+ /**
+ * Row number within {@link #pendingFrame} for the pending row. Only valid
if {@link #hasPendingRow}.
+ */
+ private int pendingRow;
+
+ /**
+ * Frame reference for the pending row. Kept alive so the pending row's
data can be read even after
+ * the channel has advanced to a new frame. Only valid if {@link
#hasPendingRow}.
+ */
+ @Nullable
+ private Frame pendingFrame;
+
+ /**
+ * Comparison widget for {@link #pendingFrame}. Only valid if {@link
#hasPendingRow}.
+ */
+ @Nullable
+ private FrameComparisonWidget pendingComparisonWidget;
+
+ /**
+ * Whether {@link FrameCombiner#combine} has been called for the current
pending row. When false,
+ * the pending row is a singleton that has not yet been passed to the
combiner. In this case,
+ * {@code flushPendingCombineRow()} can use the fast path of writing directly
from the original frame.
+ * When true, the combiner holds the accumulated state and must be used
for writing.
+ */
+ private boolean pendingCombined;
+
+ CombineState(final FrameCombiner combiner, final int combinerSlot)
+ {
+ this.combiner = combiner;
+ this.combinerSlot = combinerSlot;
+ }
+
+ /**
+ * Whether a row is currently pending.
+ */
+ boolean hasPendingRow()
+ {
+ return hasPendingRow;
+ }
+
+ /**
+ * Update the pending row.
+ */
+ void savePending(
+ final FramePlus framePlus,
+ final int channel
+ )
+ {
+ this.hasPendingRow = true;
+ this.pendingChannel = channel;
+ this.pendingRow = framePlus.rowNumber();
+ this.pendingFrame = framePlus.frame;
+ this.pendingComparisonWidget = framePlus.comparisonWidget;
+ this.pendingCombined = false;
+ }
+
+ /**
+ * Clear the pending row.
+ */
+ void clearPending()
+ {
+ this.hasPendingRow = false;
+ this.pendingFrame = null;
+ this.pendingComparisonWidget = null;
+ }
+
+ /**
+ * Compare the pending row's key against a row from another frame. Uses
frame-to-frame comparison
+ * to avoid materializing the key as a byte array.
+ */
+ int comparePendingKey(final FrameComparisonWidget otherWidget, final int
otherRow)
+ {
+ return pendingComparisonWidget.compare(pendingRow, otherWidget,
otherRow);
+ }
+
+ /**
+ * Compare the pending row to a partition boundary key.
+ */
+ int comparePendingToPartitionEnd(final RowKey partitionEnd)
+ {
+ return pendingComparisonWidget.compare(pendingRow, partitionEnd);
+ }
+ }
}
diff --git
a/processing/src/main/java/org/apache/druid/frame/processor/FrameCombiner.java
b/processing/src/main/java/org/apache/druid/frame/processor/FrameCombiner.java
new file mode 100644
index 00000000000..9be4da03889
--- /dev/null
+++
b/processing/src/main/java/org/apache/druid/frame/processor/FrameCombiner.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.processor;
+
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.segment.ColumnSelectorFactory;
+
+/**
+ * Interface for combining adjacent rows with identical sort keys during
merge-sort. Works directly with
+ * {@link Frame} objects and row indices for efficient random-access reads.
+ *
+ * Used by {@link FrameChannelMerger} to reduce intermediate data volume by
combining rows that share
+ * the same sort key.
+ */
+public interface FrameCombiner
+{
+ /**
+ * Initialize with the FrameReader for understanding frame schema and format.
+ */
+ void init(FrameReader frameReader);
+
+ /**
+ * Start a new group with the given row from the given frame.
+ * Reads and stores values from this row.
+ */
+ void reset(Frame frame, int row);
+
+ /**
+ * Fold another row into the accumulated group. Only called when the row
+ * has the same sort key as the group started in {@link #reset}.
+ */
+ void combine(Frame frame, int row);
+
+ /**
+ * Returns a {@link ColumnSelectorFactory} backed by the accumulated
combined values.
+ * Used internally by {@link FrameChannelMerger} for writing combined rows
via FrameWriter.
+ *
+ * For key columns, returns values from the first row of the group.
+ * For value columns, returns the result of combining all rows.
+ */
+ ColumnSelectorFactory getCombinedColumnSelectorFactory();
+}
diff --git
a/processing/src/main/java/org/apache/druid/frame/processor/FrameCombinerFactory.java
b/processing/src/main/java/org/apache/druid/frame/processor/FrameCombinerFactory.java
new file mode 100644
index 00000000000..b3688ba5549
--- /dev/null
+++
b/processing/src/main/java/org/apache/druid/frame/processor/FrameCombinerFactory.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.processor;
+
+/**
+ * Factory for creating {@link FrameCombiner} instances. This interface exists
because {@link FrameCombiner} objects
+ * are stateful (they accumulate values across rows within a group), so each
merge operation needs its own instance.
+ */
+@FunctionalInterface
+public interface FrameCombinerFactory
+{
+ FrameCombiner newCombiner();
+}
diff --git
a/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java
b/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java
index 3c13b39c8b1..7253feb06aa 100644
--- a/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java
+++ b/processing/src/main/java/org/apache/druid/frame/processor/SuperSorter.java
@@ -140,6 +140,8 @@ public class SuperSorter
private final int maxActiveProcessors;
private final String cancellationId;
private final boolean removeNullBytes;
+ @Nullable
+ private final FrameCombinerFactory combinerFactory;
private final Object runWorkersLock = new Object();
@GuardedBy("runWorkersLock")
@@ -238,7 +240,8 @@ public class SuperSorter
final long rowLimit,
@Nullable final String cancellationId,
final SuperSorterProgressTracker superSorterProgressTracker,
- final boolean removeNullBytes
+ final boolean removeNullBytes,
+ @Nullable final FrameCombinerFactory combinerFactory
)
{
this.inputChannels = inputChannels;
@@ -256,6 +259,7 @@ public class SuperSorter
this.cancellationId = cancellationId;
this.superSorterProgressTracker = superSorterProgressTracker;
this.removeNullBytes = removeNullBytes;
+ this.combinerFactory = combinerFactory;
for (int i = 0; i < inputChannels.size(); i++) {
inputChannelsToRead.add(i);
@@ -752,6 +756,7 @@ public class SuperSorter
removeNullBytes
),
sortKey,
+ combinerFactory != null ? combinerFactory.newCombiner() : null,
outPartitions,
rowLimit
);
diff --git
a/processing/src/main/java/org/apache/druid/frame/read/FrameReaderUtils.java
b/processing/src/main/java/org/apache/druid/frame/read/FrameReaderUtils.java
index daf95a5ac7c..4dc03297836 100644
--- a/processing/src/main/java/org/apache/druid/frame/read/FrameReaderUtils.java
+++ b/processing/src/main/java/org/apache/druid/frame/read/FrameReaderUtils.java
@@ -41,6 +41,15 @@ import java.util.function.Supplier;
*/
public class FrameReaderUtils
{
+ /**
+ * Columns that are used to directly select row memory through a {@link
ColumnSelectorFactory}.
+ */
+ public static final List<String> ROW_MEMORY_COLUMNS = List.of(
+ FrameColumnSelectorFactory.FRAME_TYPE_COLUMN,
+ FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN,
+ FrameColumnSelectorFactory.ROW_MEMORY_COLUMN
+ );
+
/**
* Returns a ByteBuffer containing data from the provided {@link Memory}.
The ByteBuffer is always newly
* created, so it is OK to change its position, limit, etc. However, it may
point directly to the backing memory
@@ -287,13 +296,7 @@ public class FrameReaderUtils
*/
private static boolean mayBeAbleToSelectRowMemory(final
ColumnSelectorFactory columnSelectorFactory)
{
- final List<String> requiredColumns = List.of(
- FrameColumnSelectorFactory.FRAME_TYPE_COLUMN,
- FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN,
- FrameColumnSelectorFactory.ROW_MEMORY_COLUMN
- );
-
- for (final String columnName : requiredColumns) {
+ for (final String columnName : ROW_MEMORY_COLUMNS) {
if (columnSelectorFactory.getColumnCapabilities(columnName) == null) {
return false;
}
diff --git
a/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
b/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
new file mode 100644
index 00000000000..26d4fa740d7
--- /dev/null
+++
b/processing/src/test/java/org/apache/druid/frame/processor/SummingFrameCombiner.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.frame.processor;
+
+import org.apache.druid.frame.Frame;
+import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.frame.segment.FrameCursor;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
+import org.apache.druid.segment.ColumnSelectorFactory;
+import org.apache.druid.segment.ColumnValueSelector;
+import org.apache.druid.segment.DimensionSelector;
+import org.apache.druid.segment.NilColumnValueSelector;
+import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.RowSignature;
+
+import javax.annotation.Nullable;
+
+/**
+ * Simple test combiner that sums a long column at {@link #sumColumnNumber}.
+ * All columns before {@link #sumColumnNumber} are treated as key columns.
+ */
+public class SummingFrameCombiner implements FrameCombiner
+{
+ private final RowSignature signature;
+ private final int sumColumnNumber;
+
+ private FrameReader frameReader;
+ private FrameCursor keyCursor;
+ private long summedValue;
+
+ public SummingFrameCombiner(final RowSignature signature, final int
sumColumnNumber)
+ {
+ this.signature = signature;
+ this.sumColumnNumber = sumColumnNumber;
+ }
+
+ @Override
+ public void init(final FrameReader frameReader)
+ {
+ this.frameReader = frameReader;
+ }
+
+ @Override
+ public void reset(final Frame frame, final int row)
+ {
+ this.keyCursor = FrameProcessors.makeCursor(frame, frameReader);
+ this.keyCursor.setCurrentRow(row);
+ this.summedValue = readLongValue(frame, row);
+ }
+
+ @Override
+ public void combine(final Frame frame, final int row)
+ {
+ this.summedValue += readLongValue(frame, row);
+ }
+
+ @Override
+ public ColumnSelectorFactory getCombinedColumnSelectorFactory()
+ {
+ return new ColumnSelectorFactory()
+ {
+ @Override
+ public DimensionSelector makeDimensionSelector(final DimensionSpec
dimensionSpec)
+ {
+ final int columnNumber =
signature.indexOf(dimensionSpec.getDimension());
+ if (columnNumber < 0) {
+ return DimensionSelector.constant(null,
dimensionSpec.getExtractionFn());
+ } else if (columnNumber == sumColumnNumber) {
+ throw new UnsupportedOperationException();
+ } else {
+ return
keyCursor.getColumnSelectorFactory().makeDimensionSelector(dimensionSpec);
+ }
+ }
+
+ @Override
+ public ColumnValueSelector<?> makeColumnValueSelector(final String
columnName)
+ {
+ final int columnNumber = signature.indexOf(columnName);
+ if (columnNumber < 0) {
+ return NilColumnValueSelector.instance();
+ } else if (columnNumber == sumColumnNumber) {
+ return new ColumnValueSelector<Long>()
+ {
+ @Override
+ public double getDouble()
+ {
+ return summedValue;
+ }
+
+ @Override
+ public float getFloat()
+ {
+ return summedValue;
+ }
+
+ @Override
+ public long getLong()
+ {
+ return summedValue;
+ }
+
+ @Override
+ public boolean isNull()
+ {
+ return false;
+ }
+
+ @Override
+ public Long getObject()
+ {
+ return summedValue;
+ }
+
+ @Override
+ public Class<Long> classOfObject()
+ {
+ return Long.class;
+ }
+
+ @Override
+ public void inspectRuntimeShape(RuntimeShapeInspector inspector)
+ {
+ // Nothing to do.
+ }
+ };
+ } else {
+ return
keyCursor.getColumnSelectorFactory().makeColumnValueSelector(columnName);
+ }
+ }
+
+ @Nullable
+ @Override
+ public ColumnCapabilities getColumnCapabilities(final String column)
+ {
+ return signature.getColumnCapabilities(column);
+ }
+ };
+ }
+
+ private long readLongValue(final Frame frame, final int row)
+ {
+ final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader);
+ cursor.setCurrentRow(row);
+ final String columnName = signature.getColumnName(sumColumnNumber);
+ return
cursor.getColumnSelectorFactory().makeColumnValueSelector(columnName).getLong();
+ }
+}
diff --git
a/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
b/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
index 680df7a3eb4..bef24a3c846 100644
---
a/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
+++
b/processing/src/test/java/org/apache/druid/frame/processor/SuperSorterTest.java
@@ -25,6 +25,8 @@ import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
+import org.apache.druid.data.input.MapBasedRow;
+import org.apache.druid.data.input.Row;
import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
@@ -52,8 +54,11 @@ import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.segment.CursorFactory;
import org.apache.druid.segment.QueryableIndexCursorFactory;
+import org.apache.druid.segment.RowAdapters;
+import org.apache.druid.segment.RowBasedCursorFactory;
import org.apache.druid.segment.TestIndex;
import org.apache.druid.segment.column.ColumnHolder;
+import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.hamcrest.MatcherAssert;
@@ -138,7 +143,8 @@ public class SuperSorterTest
SuperSorter.UNLIMITED,
null,
superSorterProgressTracker,
- false
+ false,
+ null
);
superSorter.setNoWorkRunnable(() ->
outputPartitionsFuture.set(ClusterByPartitions.oneUniversalPartition()));
@@ -175,7 +181,8 @@ public class SuperSorterTest
-1,
null,
superSorterProgressTracker,
- false
+ false,
+ null
);
final OutputChannels channels = superSorter.run().get();
@@ -211,7 +218,8 @@ public class SuperSorterTest
3,
null,
superSorterProgressTracker,
- false
+ false,
+ null
);
final OutputChannels channels = superSorter.run().get();
@@ -459,7 +467,8 @@ public class SuperSorterTest
limitHint,
null,
superSorterProgressTracker,
- false
+ false,
+ null
);
if (partitionsDeferred) {
@@ -819,6 +828,379 @@ public class SuperSorterTest
}
}
+ /**
+ * Parameterized test cases for the combiner functionality.
+ */
+ @RunWith(Parameterized.class)
+ public static class CombinerTest extends InitializedNullHandlingTest
+ {
+ private static final int FRAME_SIZE = 1_000_000;
+
+ private static final RowSignature SIGNATURE = RowSignature.builder()
+ .add("key",
ColumnType.STRING)
+ .add("value",
ColumnType.LONG)
+ .build();
+
+ private static final ClusterBy CLUSTER_BY = new ClusterBy(
+ ImmutableList.of(new KeyColumn("key", KeyOrder.ASCENDING)),
+ 0
+ );
+
+ private static final RowSignature SORTABLE_SIGNATURE =
+ FrameWriters.sortableSignature(SIGNATURE, CLUSTER_BY.getColumns());
+
+ @Parameterized.Parameters(name = "maxRowsPerFrame = {0},
maxChannelsPerMerger = {1}")
+ public static Iterable<Object[]> constructorFeeder()
+ {
+ final List<Object[]> constructors = new ArrayList<>();
+ for (int maxRowsPerFrame : new int[]{1, 2, 3, 4, 5, 6}) {
+ for (int maxChannelsPerMerger : new int[]{2, 3}) {
+ constructors.add(new Object[]{maxRowsPerFrame,
maxChannelsPerMerger});
+ }
+ }
+ return constructors;
+ }
+
+ @Rule
+ public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+ private final int maxRowsPerFrame;
+ private final int maxChannelsPerMerger;
+ private FrameProcessorExecutor exec;
+
+ public CombinerTest(final int maxRowsPerFrame, final int
maxChannelsPerMerger)
+ {
+ this.maxRowsPerFrame = maxRowsPerFrame;
+ this.maxChannelsPerMerger = maxChannelsPerMerger;
+ }
+
+ @Before
+ public void setUp()
+ {
+ exec = new FrameProcessorExecutor(
+ MoreExecutors.listeningDecorator(Execs.multiThreaded(1,
"super-sorter-combiner-test-%d"))
+ );
+ }
+
+ @After
+ public void tearDown()
+ {
+ exec.getExecutorService().shutdownNow();
+ }
+
+ /**
+ * Test that combining works in direct mode: input data with duplicate
keys produces fewer output rows
+ * with combined values.
+ */
+ @Test
+ public void testCombineDirectMode() throws Exception
+ {
+ final List<Object[][]> channelData = ImmutableList.of(
+ new Object[][]{{"a", 1L}, {"b", 2L}, {"c", 3L}},
+ new Object[][]{{"a", 10L}, {"b", 20L}, {"c", 30L}}
+ );
+
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 2,
+ SuperSorter.UNLIMITED
+ );
+
+ Assert.assertEquals(
+ ImmutableList.of(
+ ImmutableList.of("a", 11L),
+ ImmutableList.of("b", 22L),
+ ImmutableList.of("c", 33L)
+ ),
+ rows
+ );
+ }
+
+ /**
+ * Test that combining works in external mode with many channels forcing
multi-level merge.
+ */
+ @Test
+ public void testCombineExternalMode() throws Exception
+ {
+ final List<Object[][]> channelData = new ArrayList<>();
+ for (int i = 0; i < 10; i++) {
+ channelData.add(new Object[][]{{"a", 1L}, {"b", 1L}});
+ }
+
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 1,
+ SuperSorter.UNLIMITED
+ );
+
+ Assert.assertEquals(
+ ImmutableList.of(
+ ImmutableList.of("a", 10L),
+ ImmutableList.of("b", 10L)
+ ),
+ rows
+ );
+ }
+
+ /**
+ * Test that combining works correctly when output is partitioned.
+ */
+ @Test
+ public void testCombineWithPartitions() throws Exception
+ {
+ // Two partitions: [null, c) and [c, null)
+ final RowKey partitionBoundary = KeyTestUtils.createKey(
+ KeyTestUtils.createKeySignature(CLUSTER_BY.getColumns(),
SORTABLE_SIGNATURE),
+ FrameType.latestRowBased(),
+ "c"
+ );
+
+ final ClusterByPartitions partitions = new ClusterByPartitions(
+ ImmutableList.of(
+ new ClusterByPartition(null, partitionBoundary),
+ new ClusterByPartition(partitionBoundary, null)
+ )
+ );
+
+ final List<Object[][]> channelData = ImmutableList.of(
+ new Object[][]{{"a", 1L}, {"b", 2L}, {"c", 3L}, {"d", 4L}},
+ new Object[][]{{"a", 10L}, {"b", 20L}, {"c", 30L}, {"d", 40L}}
+ );
+
+ final List<List<Object>> rows = runCombiningSuperSorter(channelData,
partitions, 2, SuperSorter.UNLIMITED);
+
+ Assert.assertEquals(
+ ImmutableList.of(
+ ImmutableList.of("a", 11L),
+ ImmutableList.of("b", 22L),
+ ImmutableList.of("c", 33L),
+ ImmutableList.of("d", 44L)
+ ),
+ rows
+ );
+ }
+
+ /**
+ * Test combining when all rows across all channels have the same key.
+ */
+ @Test
+ public void testCombineAllSameKey() throws Exception
+ {
+ final List<Object[][]> channelData = new ArrayList<>();
+ for (int i = 0; i < 5; i++) {
+ channelData.add(new Object[][]{{"x", 1L}, {"x", 2L}, {"x", 3L}});
+ }
+
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 2,
+ SuperSorter.UNLIMITED
+ );
+
+ // 5 channels * (1 + 2 + 3) = 30
+ Assert.assertEquals(
+ ImmutableList.of(ImmutableList.of("x", 30L)),
+ rows
+ );
+ }
+
+ /**
+ * Test combining with row limits.
+ */
+ @Test
+ public void testCombineWithRowLimit() throws Exception
+ {
+ final List<Object[][]> channelData = ImmutableList.of(
+ new Object[][]{{"a", 1L}, {"b", 2L}, {"c", 3L}},
+ new Object[][]{{"a", 10L}, {"b", 20L}, {"c", 30L}}
+ );
+
+ for (int limit = 1; limit <= 3; limit++) {
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 2,
+ limit
+ );
+
+ Assert.assertEquals(
+ StringUtils.format("limit[%d]: expected exactly %d row(s), got
%d", limit, limit, rows.size()),
+ limit,
+ rows.size()
+ );
+ Assert.assertEquals(ImmutableList.of("a", 11L), rows.get(0));
+ if (limit >= 2) {
+ Assert.assertEquals(ImmutableList.of("b", 22L), rows.get(1));
+ }
+ if (limit >= 3) {
+ Assert.assertEquals(ImmutableList.of("c", 33L), rows.get(2));
+ }
+ }
+ }
+
+ /**
+ * Test combining with rowLimit = 1 and all-same-key input: single
combined row.
+ */
+ @Test
+ public void testCombineAllSameKeyWithRowLimit1() throws Exception
+ {
+ final List<Object[][]> channelData = ImmutableList.of(
+ new Object[][]{{"x", 1L}},
+ new Object[][]{{"x", 2L}},
+ new Object[][]{{"x", 3L}}
+ );
+
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 2,
+ 1
+ );
+
+ // All rows combine to one; rowLimit = 1 is satisfied.
+ Assert.assertEquals(
+ ImmutableList.of(ImmutableList.of("x", 6L)),
+ rows
+ );
+ }
+
+ /**
+ * Test combining with a single input channel where duplicate keys are
within the same sorted stream.
+ */
+ @Test
+ public void testCombineSingleChannel() throws Exception
+ {
+ final List<Object[][]> channelData = ImmutableList.of(
+ new Object[][]{{"a", 1L}, {"a", 2L}, {"b", 3L}, {"b", 4L}}
+ );
+
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 1,
+ SuperSorter.UNLIMITED
+ );
+
+ Assert.assertEquals(
+ ImmutableList.of(
+ ImmutableList.of("a", 3L),
+ ImmutableList.of("b", 7L)
+ ),
+ rows
+ );
+ }
+
+ /**
+ * Test combining with empty input channels.
+ */
+ @Test
+ public void testCombineEmptyInput() throws Exception
+ {
+ final List<Object[][]> channelData = ImmutableList.of(new Object[][]{});
+
+ final List<List<Object>> rows = runCombiningSuperSorter(
+ channelData,
+ ClusterByPartitions.oneUniversalPartition(),
+ 1,
+ SuperSorter.UNLIMITED
+ );
+
+ Assert.assertEquals(ImmutableList.of(), rows);
+ }
+
+ /**
+ * Helper that runs a combining SuperSorter with the given channel data,
partitions, maxActiveProcessors,
+ * and rowLimit. Returns all output rows across all partitions.
+ */
+ private List<List<Object>> runCombiningSuperSorter(
+ final List<Object[][]> channelData,
+ final ClusterByPartitions partitions,
+ final int maxActiveProcessors,
+ final long rowLimit
+ ) throws Exception
+ {
+ final FrameReader frameReader = FrameReader.create(SORTABLE_SIGNATURE);
+
+ final List<ReadableFrameChannel> channels = new ArrayList<>();
+ for (final Object[][] data : channelData) {
+ channels.add(makeFrameChannel(data));
+ }
+
+ final File tempFolder = temporaryFolder.newFolder();
+
+ final SuperSorter superSorter = new SuperSorter(
+ channels,
+ frameReader,
+ CLUSTER_BY.getColumns(),
+ Futures.immediateFuture(partitions),
+ exec,
+ FrameProcessorDecorator.NONE,
+ new FileOutputChannelFactory(tempFolder, FRAME_SIZE, null,
FrameTestUtil.WT_CONTEXT_LEGACY),
+ new FileOutputChannelFactory(tempFolder, FRAME_SIZE, null,
FrameTestUtil.WT_CONTEXT_LEGACY),
+ FrameType.latestRowBased(),
+ maxActiveProcessors,
+ maxChannelsPerMerger,
+ rowLimit,
+ null,
+ new SuperSorterProgressTracker(),
+ false,
+ () -> new SummingFrameCombiner(SORTABLE_SIGNATURE, 1)
+ );
+
+ final OutputChannels outputChannels = superSorter.run().get();
+
+ final List<List<Object>> rows = new ArrayList<>();
+ for (final OutputChannel channel : outputChannels.getAllChannels()) {
+ FrameTestUtil.readRowsFromFrameChannel(channel.getReadableChannel(),
frameReader)
+ .forEach(rows::add);
+ }
+
+ return rows;
+ }
+
+ private ReadableFrameChannel makeFrameChannel(final Object[][] rows)
throws IOException
+ {
+ final List<Row> rowList = new ArrayList<>();
+ for (final Object[] row : rows) {
+ final Map<String, Object> map = new HashMap<>();
+ for (int i = 0; i < SIGNATURE.size(); i++) {
+ map.put(SIGNATURE.getColumnName(i), row[i]);
+ }
+ rowList.add(new MapBasedRow(0L, map));
+ }
+
+ final RowBasedCursorFactory<Row> cursorFactory =
+ new RowBasedCursorFactory<>(
+ Sequences.simple(rowList),
+ RowAdapters.standardRow(),
+ SIGNATURE
+ );
+
+ final Sequence<Frame> frames =
+ FrameSequenceBuilder.fromCursorFactory(cursorFactory)
+ .maxRowsPerFrame(maxRowsPerFrame)
+ .sortBy(CLUSTER_BY.getColumns())
+
.allocator(ArenaMemoryAllocator.create(ByteBuffer.allocate(FRAME_SIZE)))
+ .frameType(FrameType.latestRowBased())
+ .frames();
+
+ final BlockingQueueFrameChannel channel = new
BlockingQueueFrameChannel(100);
+ frames.forEach(frame -> {
+ try {
+ channel.writable().write(frame);
+ }
+ catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ channel.writable().close();
+ return channel.readable();
+ }
+ }
+
/**
* Distribute frames round-robin to some number of channels.
*/
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]