gianm commented on code in PR #13506: URL: https://github.com/apache/druid/pull/13506#discussion_r1114188415
########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. + private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. + startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { + if (nextIterationRunnable != null) { + final Runnable tmp = nextIterationRunnable; + nextIterationRunnable = null; + tmp.run(); + } + + final int markCmp = compareMarks(); + final boolean match = markCmp == 0 && !trackers.get(LEFT).hasPartiallyNullMark(); + + // Fetch new frames until the algorithm can proceed: + // 1) If marked keys are equal on both sides, at least one side has a complete set of rows for the marked key. + // 2) Otherwise, each side has at least one readable row, or is completely done. (Checked by "while" condition.) + if (match && trackerWithCompleteSetForCurrentKey < 0) { + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) { + trackerWithCompleteSetForCurrentKey = i; + break; + } + } + + if (trackerWithCompleteSetForCurrentKey < 0) { + // Algorithm cannot proceed; fetch more frames. + return nextAwait(); + } + } + + if (match || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && joinType.isRighty())) { + // Emit row, if there's room in the current frameWriter. + joinColumnSelectorFactory.cmp = markCmp; + joinColumnSelectorFactory.match = match; + + if (!frameWriter.addSelection()) { + if (frameWriter.getNumRows() > 0) { + // Out of space in the current frame. Run again without moving cursors. + flushCurrentFrame(); + return ReturnOrAwait.runAgain(); + } else { + throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity()); + } + } + } + + // Advance one or both trackers. + if (match) { + // Matching keys. First advance the tracker with the complete set. + final Tracker tracker = trackers.get(trackerWithCompleteSetForCurrentKey); + final Tracker otherTracker = trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT); + + tracker.advance(); + if (!tracker.isCurrentSameKeyAsMark()) { + // Reached end of complete set. Advance the other tracker. + otherTracker.advance(); + + // On next iteration (when we're sure to have data) either rewind the complete-set tracker, or update marks + // of both, as appropriate. + onNextIteration(() -> { + if (otherTracker.isCurrentSameKeyAsMark()) { + otherTracker.markCurrent(); // Set mark to enable cleanup of old frames. + tracker.rewindToMark(); + } else { + // Reached end of the other side too. Advance marks on both trackers. + tracker.markCurrent(); + otherTracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + }); + } + } else { + final int trackerToAdvance; + + if (markCmp < 0) { + trackerToAdvance = LEFT; + } else if (markCmp > 0) { + trackerToAdvance = RIGHT; + } else { + // Key is null on both sides. Note that there is a preference for running through the left side first + // on a FULL join. It doesn't really matter which side we run through first, but we do need to be consistent + // for the benefit of the logic in "shouldEmitColumnValue". + trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT; + } + + final Tracker tracker = trackers.get(trackerToAdvance); + + tracker.advance(); + + // On next iteration (when we're sure to have data), update mark if the key changed. + onNextIteration(() -> { + if (!tracker.isCurrentSameKeyAsMark()) { + tracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + }); + } + } + + if (trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd()) { + // Both channels completely done. + flushCurrentFrame(); + return ReturnOrAwait.returnObject(0L); + } else { + // Keep reading. + return nextAwait(); + } + } + + @Override + public void cleanup() throws IOException + { + FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter, () -> trackers.forEach(Tracker::clear)); + } + + /** + * Returns a {@link ReturnOrAwait#awaitAll} for the channel numbers that need more data and have not yet hit their + * buffered-bytes limit, {@link Limits#MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN}. + * + * If all channels have hit their limit, throws {@link MSQException} with {@link TooManyRowsWithSameKeyFault}. + */ + private ReturnOrAwait<Long> nextAwait() + { + final IntSet awaitSet = new IntOpenHashSet(); + int trackerAtLimit = -1; + + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.needsMoreData()) { + if (tracker.totalBytesBuffered() < Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + awaitSet.add(i); + } else if (trackerAtLimit < 0) { + trackerAtLimit = i; + } + } + } + + if (awaitSet.isEmpty() && trackerAtLimit > 0) { + // All trackers that need more data are at their max buffered bytes limit. Generate a nice exception. + final Tracker tracker = trackers.get(trackerAtLimit); + if (tracker.totalBytesBuffered() > Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + // Generate a nice exception. + throw new MSQException( + new TooManyRowsWithSameKeyFault( + tracker.readMarkKey(), + tracker.totalBytesBuffered(), + Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN + ) + ); + } + } + + return ReturnOrAwait.awaitAll(awaitSet); + } + + /** + * Compares the marked rows of the two {@link #trackers}. + * + * @throws IllegalStateException if either tracker does not have a marked row and is not completely done + */ + private int compareMarks() + { + final Tracker leftTracker = trackers.get(LEFT); + final Tracker rightTracker = trackers.get(RIGHT); + + Preconditions.checkState(leftTracker.hasMark() || leftTracker.isAtEnd(), "left.hasMark || left.isAtEnd"); + Preconditions.checkState(rightTracker.hasMark() || rightTracker.isAtEnd(), "right.hasMark || right.isAtEnd"); + + if (leftTracker.markFrame < 0) { + return rightTracker.markFrame < 0 ? 0 : 1; + } else if (rightTracker.markFrame < 0) { + return -1; + } else { + final FrameHolder leftHolder = leftTracker.holders.get(leftTracker.markFrame); + final FrameHolder rightHolder = rightTracker.holders.get(rightTracker.markFrame); + return leftHolder.comparisonWidget.compare( + leftTracker.markRow, + rightHolder.comparisonWidget, + rightTracker.markRow + ); + } + } + + /** + * Pushes a frame from the indicated channel into the appropriate tracker. Returns true if a frame was pushed + * or if the channel is finished. + */ + private boolean pushNextFrame(final int channelNumber) + { + final ReadableFrameChannel channel = inputChannels.get(channelNumber); + final Tracker tracker = trackers.get(channelNumber); + + if (!channel.isFinished() && !channel.canRead()) { + return false; + } else if (channel.isFinished()) { + tracker.push(null); + return true; + } else { + final Frame frame = channel.read(); + + if (frame.numRows() == 0) { + // Skip, read next. + return false; + } else { + tracker.push(frame); + return true; + } + } + } + + private void onNextIteration(final Runnable runnable) + { + if (nextIterationRunnable != null) { + throw new ISE("postAdvanceRunnable already set"); + } else { + nextIterationRunnable = runnable; + } + } + + private void startNewFrameIfNeeded() + { + if (frameWriter == null) { + frameWriter = frameWriterFactory.newFrameWriter(joinColumnSelectorFactory); + } + } + + private void flushCurrentFrame() throws IOException + { + if (frameWriter != null) { + if (frameWriter.getNumRows() > 0) { + final Frame frame = Frame.wrap(frameWriter.toByteArray()); + frameWriter.close(); + frameWriter = null; + outputChannel.write(new FrameWithPartition(frame, FrameWithPartition.NO_PARTITION)); + } + } + } + + /** + * Tracks the current set of rows that have the same key from a sequence of frames. + * + * markFrame and markRow are set when we encounter a new key, which enables rewinding and re-reading data with the + * same key. + */ + private class Tracker + { + private final List<FrameHolder> holders = new ArrayList<>(); + private final int channelNumber; + + // markFrame and markRow are the first frame and row with the current key. + private int markFrame = -1; + private int markRow = -1; + + // currentFrame is the frame containing the current cursor row. + private int currentFrame = -1; + + // done indicates that no more data is available in the channel. + private boolean done; + + public Tracker(int channelNumber) + { + this.channelNumber = channelNumber; + } + + /** + * Adds a holder for a frame. If this is the first frame, sets the current cursor position and mark to the first + * row of the frame. Otherwise, the cursor position and mark are not changed. + * + * Pushing a null frame indicates no more frames are coming. + * + * @param frame frame, or null indicating no more frames are coming + */ + public void push(final Frame frame) + { + if (frame == null) { + done = true; + return; + } + + if (done) { + throw new ISE("Cannot push frames when already done"); + } + + final boolean atEndOfPushedData = isAtEndOfPushedData(); + final FrameReader frameReader = inputs.get(channelNumber).getChannelFrameReader(); + final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader); + final FrameComparisonWidget comparisonWidget = + frameReader.makeComparisonWidget(frame, keyColumns.get(channelNumber)); + + final RowSignature.Builder keySignatureBuilder = RowSignature.builder(); + for (final KeyColumn keyColumn : keyColumns.get(channelNumber)) { + keySignatureBuilder.add( + keyColumn.columnName(), + frameReader.signature().getColumnType(keyColumn.columnName()).orElse(null) + ); + } + + holders.add( + new FrameHolder( + frame, + RowKeyReader.create(keySignatureBuilder.build()), + cursor, + comparisonWidget + ) + ); + + if (atEndOfPushedData) { + // Move currentFrame so it points at the next row, which we now have, instead of an "isDone" cursor. + currentFrame = currentFrame < 0 ? 0 : currentFrame + 1; + } + + if (markFrame < 0) { + // Cleared mark means we want the current row to be marked. + markFrame = currentFrame; + markRow = 0; + } + } + + /** + * Number of bytes currently buffered in {@link #holders}. + */ + public long totalBytesBuffered() + { + long bytes = 0; + for (final FrameHolder holder : holders) { + bytes += holder.frame.numBytes(); + } + return bytes; + } + + /** + * Cursor containing the current row. + */ + @Nullable + public FrameCursor currentCursor() + { + if (currentFrame < 0) { + return null; + } else { + return holders.get(currentFrame).cursor; + } + } + + /** + * Advances the current row (the current row of {@link #currentFrame}). After calling this method, + * {@link #isAtEndOfPushedData()} may start returning true. + */ + public void advance() + { + assert !isAtEndOfPushedData(); + + final FrameHolder currentHolder = holders.get(currentFrame); + + currentHolder.cursor.advance(); + + if (currentHolder.cursor.isDone() && currentFrame + 1 < holders.size()) { + currentFrame++; + holders.get(currentFrame).cursor.reset(); + } Review Comment: There's a `holders.clear()` in `markCurrent()`. We can't get rid of the old holders until the mark moves on. (The purpose of holding on to them is to enable us to collect a complete set of rows for the marked key.) I added comments to `holders` clarifying this. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
