gianm commented on code in PR #13506: URL: https://github.com/apache/druid/pull/13506#discussion_r1114174262
########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. + private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. + startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { + if (nextIterationRunnable != null) { + final Runnable tmp = nextIterationRunnable; + nextIterationRunnable = null; + tmp.run(); + } + + final int markCmp = compareMarks(); + final boolean match = markCmp == 0 && !trackers.get(LEFT).hasPartiallyNullMark(); Review Comment: It's to achieve proper join semantics. I added a comment: ``` // Two rows match if the keys compare equal _and_ neither key has a null component. (x JOIN y ON x.a = y.a does // not match rows where "x.a" is null.) ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
