imply-cheddar commented on code in PR #13506: URL: https://github.com/apache/druid/pull/13506#discussion_r1042823240
########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import 
org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. 
+ * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; Review Comment: Why not have two fields, a left ReadableFrameChannel and a right ReadableFrameChannel, instead of a List that you call `.get()` on all the time?
########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import 
org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. 
+ * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. 
+ */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. + private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. 
+ for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. + startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { + if (nextIterationRunnable != null) { + final Runnable tmp = nextIterationRunnable; + nextIterationRunnable = null; + tmp.run(); + } + + final int markCmp = compareMarks(); + final boolean match = markCmp == 0 && !trackers.get(LEFT).hasPartiallyNullMark(); + + // Fetch new frames until the algorithm can proceed: + // 1) If marked keys are equal on both sides, at least one side has a complete set of rows for the marked key. + // 2) Otherwise, each side has at least one readable row, or is completely done. (Checked by "while" condition.) + if (match && trackerWithCompleteSetForCurrentKey < 0) { + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) { Review Comment: > (pushNextFrame(i) && tracker.hasCompleteSetForMark()) It's unclear to me why it is sufficient to only push a single next frame and check. This sort of thing is usually in a while-loop, no? ########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import 
org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. 
(It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. 
+ */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. + private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. 
+ startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { Review Comment: This boolean expression is large and involves a lot of negation. It might be nicer to negate one large OR instead of ANDing three negated terms, e.g. `while (!(trackers.get(LEFT).needsMoreData() || trackers.get(RIGHT).needsMoreData() || (trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())))`. More generally, if there's any way to make this condition simpler, that would be nice. ########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import 
org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. 
+ * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. 
+ private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. + startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { + if (nextIterationRunnable != null) { + final Runnable tmp = nextIterationRunnable; + nextIterationRunnable = null; + tmp.run(); + } + + final int markCmp = compareMarks(); + final boolean match = markCmp == 0 && !trackers.get(LEFT).hasPartiallyNullMark(); + + // Fetch new frames until the algorithm can proceed: + // 1) If marked keys are equal on both sides, at least one side has a complete set of rows for the marked key. 
+ // 2) Otherwise, each side has at least one readable row, or is completely done. (Checked by "while" condition.) + if (match && trackerWithCompleteSetForCurrentKey < 0) { + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) { + trackerWithCompleteSetForCurrentKey = i; + break; + } + } + + if (trackerWithCompleteSetForCurrentKey < 0) { + // Algorithm cannot proceed; fetch more frames. + return nextAwait(); + } + } + + if (match || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && joinType.isRighty())) { + // Emit row, if there's room in the current frameWriter. + joinColumnSelectorFactory.cmp = markCmp; + joinColumnSelectorFactory.match = match; + + if (!frameWriter.addSelection()) { + if (frameWriter.getNumRows() > 0) { + // Out of space in the current frame. Run again without moving cursors. + flushCurrentFrame(); + return ReturnOrAwait.runAgain(); + } else { + throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity()); + } + } + } + + // Advance one or both trackers. + if (match) { + // Matching keys. First advance the tracker with the complete set. + final Tracker tracker = trackers.get(trackerWithCompleteSetForCurrentKey); + final Tracker otherTracker = trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT); + + tracker.advance(); + if (!tracker.isCurrentSameKeyAsMark()) { + // Reached end of complete set. Advance the other tracker. + otherTracker.advance(); + + // On next iteration (when we're sure to have data) either rewind the complete-set tracker, or update marks + // of both, as appropriate. + onNextIteration(() -> { + if (otherTracker.isCurrentSameKeyAsMark()) { + otherTracker.markCurrent(); // Set mark to enable cleanup of old frames. + tracker.rewindToMark(); + } else { + // Reached end of the other side too. Advance marks on both trackers. 
+ tracker.markCurrent(); + otherTracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + }); + } + } else { + final int trackerToAdvance; + + if (markCmp < 0) { + trackerToAdvance = LEFT; + } else if (markCmp > 0) { + trackerToAdvance = RIGHT; + } else { + // Key is null on both sides. Note that there is a preference for running through the left side first + // on a FULL join. It doesn't really matter which side we run through first, but we do need to be consistent + // for the benefit of the logic in "shouldEmitColumnValue". + trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT; + } + + final Tracker tracker = trackers.get(trackerToAdvance); + + tracker.advance(); + + // On next iteration (when we're sure to have data), update mark if the key changed. + onNextIteration(() -> { + if (!tracker.isCurrentSameKeyAsMark()) { + tracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + }); + } + } + + if (trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd()) { + // Both channels completely done. + flushCurrentFrame(); + return ReturnOrAwait.returnObject(0L); + } else { + // Keep reading. + return nextAwait(); + } + } + + @Override + public void cleanup() throws IOException + { + FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter, () -> trackers.forEach(Tracker::clear)); + } + + /** + * Returns a {@link ReturnOrAwait#awaitAll} for the channel numbers that need more data and have not yet hit their + * buffered-bytes limit, {@link Limits#MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN}. + * + * If all channels have hit their limit, throws {@link MSQException} with {@link TooManyRowsWithSameKeyFault}. 
+ */ + private ReturnOrAwait<Long> nextAwait() + { + final IntSet awaitSet = new IntOpenHashSet(); + int trackerAtLimit = -1; + + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.needsMoreData()) { + if (tracker.totalBytesBuffered() < Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + awaitSet.add(i); + } else if (trackerAtLimit < 0) { + trackerAtLimit = i; + } + } + } + + if (awaitSet.isEmpty() && trackerAtLimit > 0) { + // All trackers that need more data are at their max buffered bytes limit. Generate a nice exception. + final Tracker tracker = trackers.get(trackerAtLimit); + if (tracker.totalBytesBuffered() > Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + // Generate a nice exception. + throw new MSQException( + new TooManyRowsWithSameKeyFault( + tracker.readMarkKey(), + tracker.totalBytesBuffered(), + Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN + ) + ); + } + } + + return ReturnOrAwait.awaitAll(awaitSet); + } + + /** + * Compares the marked rows of the two {@link #trackers}. + * + * @throws IllegalStateException if either tracker does not have a marked row and is not completely done + */ + private int compareMarks() + { + final Tracker leftTracker = trackers.get(LEFT); + final Tracker rightTracker = trackers.get(RIGHT); + + Preconditions.checkState(leftTracker.hasMark() || leftTracker.isAtEnd(), "left.hasMark || left.isAtEnd"); + Preconditions.checkState(rightTracker.hasMark() || rightTracker.isAtEnd(), "right.hasMark || right.isAtEnd"); + + if (leftTracker.markFrame < 0) { + return rightTracker.markFrame < 0 ? 
0 : 1; + } else if (rightTracker.markFrame < 0) { + return -1; + } else { + final FrameHolder leftHolder = leftTracker.holders.get(leftTracker.markFrame); + final FrameHolder rightHolder = rightTracker.holders.get(rightTracker.markFrame); + return leftHolder.comparisonWidget.compare( + leftTracker.markRow, + rightHolder.comparisonWidget, + rightTracker.markRow + ); + } + } + + /** + * Pushes a frame from the indicated channel into the appropriate tracker. Returns true if a frame was pushed + * or if the channel is finished. + */ + private boolean pushNextFrame(final int channelNumber) + { + final ReadableFrameChannel channel = inputChannels.get(channelNumber); + final Tracker tracker = trackers.get(channelNumber); + + if (!channel.isFinished() && !channel.canRead()) { + return false; + } else if (channel.isFinished()) { + tracker.push(null); + return true; + } else { + final Frame frame = channel.read(); + + if (frame.numRows() == 0) { + // Skip, read next. + return false; + } else { + tracker.push(frame); + return true; + } + } + } + + private void onNextIteration(final Runnable runnable) + { + if (nextIterationRunnable != null) { + throw new ISE("postAdvanceRunnable already set"); + } else { + nextIterationRunnable = runnable; + } + } + + private void startNewFrameIfNeeded() + { + if (frameWriter == null) { + frameWriter = frameWriterFactory.newFrameWriter(joinColumnSelectorFactory); + } + } + + private void flushCurrentFrame() throws IOException + { + if (frameWriter != null) { + if (frameWriter.getNumRows() > 0) { + final Frame frame = Frame.wrap(frameWriter.toByteArray()); + frameWriter.close(); + frameWriter = null; + outputChannel.write(new FrameWithPartition(frame, FrameWithPartition.NO_PARTITION)); + } + } + } + + /** + * Tracks the current set of rows that have the same key from a sequence of frames. + * + * markFrame and markRow are set when we encounter a new key, which enables rewinding and re-reading data with the + * same key. 
+ */ + private class Tracker + { + private final List<FrameHolder> holders = new ArrayList<>(); + private final int channelNumber; + + // markFrame and markRow are the first frame and row with the current key. + private int markFrame = -1; + private int markRow = -1; + + // currentFrame is the frame containing the current cursor row. + private int currentFrame = -1; + + // done indicates that no more data is available in the channel. + private boolean done; + + public Tracker(int channelNumber) + { + this.channelNumber = channelNumber; + } + + /** + * Adds a holder for a frame. If this is the first frame, sets the current cursor position and mark to the first + * row of the frame. Otherwise, the cursor position and mark are not changed. + * + * Pushing a null frame indicates no more frames are coming. + * + * @param frame frame, or null indicating no more frames are coming + */ + public void push(final Frame frame) + { + if (frame == null) { + done = true; + return; + } + + if (done) { + throw new ISE("Cannot push frames when already done"); + } + + final boolean atEndOfPushedData = isAtEndOfPushedData(); + final FrameReader frameReader = inputs.get(channelNumber).getChannelFrameReader(); + final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader); + final FrameComparisonWidget comparisonWidget = + frameReader.makeComparisonWidget(frame, keyColumns.get(channelNumber)); + + final RowSignature.Builder keySignatureBuilder = RowSignature.builder(); + for (final KeyColumn keyColumn : keyColumns.get(channelNumber)) { + keySignatureBuilder.add( + keyColumn.columnName(), + frameReader.signature().getColumnType(keyColumn.columnName()).orElse(null) + ); + } + + holders.add( + new FrameHolder( + frame, + RowKeyReader.create(keySignatureBuilder.build()), + cursor, + comparisonWidget + ) + ); + + if (atEndOfPushedData) { + // Move currentFrame so it points at the next row, which we now have, instead of an "isDone" cursor. + currentFrame = currentFrame < 0 ? 
0 : currentFrame + 1; + } + + if (markFrame < 0) { + // Cleared mark means we want the current row to be marked. + markFrame = currentFrame; + markRow = 0; + } + } + + /** + * Number of bytes currently buffered in {@link #holders}. + */ + public long totalBytesBuffered() + { + long bytes = 0; + for (final FrameHolder holder : holders) { + bytes += holder.frame.numBytes(); + } + return bytes; + } + + /** + * Cursor containing the current row. + */ + @Nullable + public FrameCursor currentCursor() + { + if (currentFrame < 0) { + return null; + } else { + return holders.get(currentFrame).cursor; + } + } + + /** + * Advances the current row (the current row of {@link #currentFrame}). After calling this method, + * {@link #isAtEndOfPushedData()} may start returning true. + */ + public void advance() + { + assert !isAtEndOfPushedData(); + + final FrameHolder currentHolder = holders.get(currentFrame); + + currentHolder.cursor.advance(); + + if (currentHolder.cursor.isDone() && currentFrame + 1 < holders.size()) { + currentFrame++; + holders.get(currentFrame).cursor.reset(); + } Review Comment: This never seems to release the previous FrameHolder objects. Couldn't it effectively null out the old holder here? ########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import 
org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. 
(It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. 
+ */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. + private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. 
+ startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { + if (nextIterationRunnable != null) { + final Runnable tmp = nextIterationRunnable; + nextIterationRunnable = null; + tmp.run(); + } + + final int markCmp = compareMarks(); + final boolean match = markCmp == 0 && !trackers.get(LEFT).hasPartiallyNullMark(); Review Comment: As I read this line, it's unclear to me why "partially null" matters. Maybe it'll become clearer as I read more. I went back up to the class-level javadoc and skimmed it once more, but didn't see the relevance of nulls called out. ########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import 
org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. If both sides have a complete set, we break ties by choosing the + * left side. 
+ * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; Review Comment: Or, given that you have multiple parallel lists here, maybe what you need is a `JoinChannelContainer` or something similar that holds a reference to each of `ReadableInput`, `ReadableFrameChannel`, `List<KeyColumn>`, and `Tracker`? ########## extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java: ########## @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.channel.FrameWithPartition; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.WritableFrameChannel; +import org.apache.druid.frame.key.FrameComparisonWidget; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.RowKey; +import org.apache.druid.frame.key.RowKeyReader; +import org.apache.druid.frame.processor.FrameProcessor; +import org.apache.druid.frame.processor.FrameProcessors; +import org.apache.druid.frame.processor.FrameRowTooLargeException; +import org.apache.druid.frame.processor.ReturnOrAwait; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameCursor; +import org.apache.druid.frame.write.FrameWriter; +import org.apache.druid.frame.write.FrameWriterFactory; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.Limits; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.dimension.DimensionSpec; +import org.apache.druid.query.filter.ValueMatcher; +import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector; +import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.ColumnValueSelector; +import org.apache.druid.segment.Cursor; +import org.apache.druid.segment.DimensionSelector; +import 
org.apache.druid.segment.DimensionSelectorUtils; +import org.apache.druid.segment.IdLookup; +import org.apache.druid.segment.NilColumnValueSelector; +import org.apache.druid.segment.column.ColumnCapabilities; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.data.IndexedInts; +import org.apache.druid.segment.data.ZeroIndexedInts; +import org.apache.druid.segment.join.JoinPrefixUtils; +import org.apache.druid.segment.join.JoinType; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Processor for a sort-merge join of two inputs. + * + * Prerequisites: + * + * 1) Two inputs, both of which are stages; i.e. {@link ReadableInput#hasChannel()}. + * + * 2) Conditions are all simple equalities. Validated by {@link SortMergeJoinFrameProcessorFactory#validateCondition} + * and then transformed to lists of key columns by {@link SortMergeJoinFrameProcessorFactory#toKeyColumns}. + * + * 3) Both inputs are comprised of {@link org.apache.druid.frame.FrameType#ROW_BASED} frames, are sorted by the same + * key, and that key can be used to check the provided condition. Validated by + * {@link SortMergeJoinFrameProcessorFactory#validateInputFrameSignatures}. + * + * Algorithm: + * + * 1) Read current key from each side of the join. + * + * 2) If there is no match, emit or skip the row for the earlier key, as appropriate, based on the join type. + * + * 3) If there is a match, identify a complete set on one side or the other. (It doesn't matter which side has the + * complete set, but we need it on one of them.) We mark the first row for the key using {@link Tracker#markCurrent()} + * and find complete sets using {@link Tracker#hasCompleteSetForMark()}. Once we find one, we store it in + * {@link #trackerWithCompleteSetForCurrentKey}. 
If both sides have a complete set, we break ties by choosing the + * left side. + * + * 4) Once a complete set for the current key is identified: for each row on the *other* side, loop through the entire + * set of rows on {@link #trackerWithCompleteSetForCurrentKey}, and emit that many joined rows. + * + * 5) Once we process the final row on the *other* side, reset both marks with {@link Tracker#markCurrent()} and + * continue the algorithm. + */ +public class SortMergeJoinFrameProcessor implements FrameProcessor<Long> +{ + private static final int LEFT = 0; + private static final int RIGHT = 1; + + /** + * Two sides of the join. Must be channels; i.e. {@link ReadableInput#hasChannel()} must be true. Two-element array: + * {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableInput> inputs; + + /** + * Channels from {@link #inputs}. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<ReadableFrameChannel> inputChannels; + + /** + * Names of the key columns on each side of the join, unprefixed. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<List<KeyColumn>> keyColumns; + + /** + * Trackers for each side of the join. Two-element array: {@link #LEFT} and {@link #RIGHT}. + */ + private final List<Tracker> trackers; + + private final WritableFrameChannel outputChannel; + private final FrameWriterFactory frameWriterFactory; + private final String rightPrefix; + private final JoinType joinType; + private final JoinColumnSelectorFactory joinColumnSelectorFactory = new JoinColumnSelectorFactory(); + private FrameWriter frameWriter = null; + + // Used by runIncrementally to defer certain logic to the next run. + private Runnable nextIterationRunnable = null; + + // Used by runIncrementally to remember which tracker has the complete set for the current key. 
+ private int trackerWithCompleteSetForCurrentKey = -1; + + SortMergeJoinFrameProcessor( + ReadableInput left, + ReadableInput right, + WritableFrameChannel outputChannel, + FrameWriterFactory frameWriterFactory, + String rightPrefix, + List<List<KeyColumn>> keyColumns, + JoinType joinType + ) + { + this.inputs = ImmutableList.of(left, right); + this.inputChannels = inputs.stream().map(ReadableInput::getChannel).collect(Collectors.toList()); + this.keyColumns = keyColumns; + this.outputChannel = outputChannel; + this.frameWriterFactory = frameWriterFactory; + this.rightPrefix = rightPrefix; + this.joinType = joinType; + this.trackers = ImmutableList.of(new Tracker(LEFT), new Tracker(RIGHT)); + } + + @Override + public List<ReadableFrameChannel> inputChannels() + { + return inputChannels; + } + + @Override + public List<WritableFrameChannel> outputChannels() + { + return Collections.singletonList(outputChannel); + } + + @Override + public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws IOException + { + // Fetch enough frames such that each tracker has one readable row. + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) { + return nextAwait(); + } + } + + // Initialize new output frame, if needed. + startNewFrameIfNeeded(); + + while (!trackers.get(LEFT).needsMoreData() + && !trackers.get(RIGHT).needsMoreData() + && !(trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd())) { + if (nextIterationRunnable != null) { + final Runnable tmp = nextIterationRunnable; + nextIterationRunnable = null; + tmp.run(); + } + + final int markCmp = compareMarks(); + final boolean match = markCmp == 0 && !trackers.get(LEFT).hasPartiallyNullMark(); + + // Fetch new frames until the algorithm can proceed: + // 1) If marked keys are equal on both sides, at least one side has a complete set of rows for the marked key. 
+ // 2) Otherwise, each side has at least one readable row, or is completely done. (Checked by "while" condition.) + if (match && trackerWithCompleteSetForCurrentKey < 0) { + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && tracker.hasCompleteSetForMark())) { + trackerWithCompleteSetForCurrentKey = i; + break; + } + } + + if (trackerWithCompleteSetForCurrentKey < 0) { + // Algorithm cannot proceed; fetch more frames. + return nextAwait(); + } + } + + if (match || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && joinType.isRighty())) { + // Emit row, if there's room in the current frameWriter. + joinColumnSelectorFactory.cmp = markCmp; + joinColumnSelectorFactory.match = match; + + if (!frameWriter.addSelection()) { + if (frameWriter.getNumRows() > 0) { + // Out of space in the current frame. Run again without moving cursors. + flushCurrentFrame(); + return ReturnOrAwait.runAgain(); + } else { + throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity()); + } + } + } + + // Advance one or both trackers. + if (match) { + // Matching keys. First advance the tracker with the complete set. + final Tracker tracker = trackers.get(trackerWithCompleteSetForCurrentKey); + final Tracker otherTracker = trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT); + + tracker.advance(); + if (!tracker.isCurrentSameKeyAsMark()) { + // Reached end of complete set. Advance the other tracker. + otherTracker.advance(); + + // On next iteration (when we're sure to have data) either rewind the complete-set tracker, or update marks + // of both, as appropriate. + onNextIteration(() -> { + if (otherTracker.isCurrentSameKeyAsMark()) { + otherTracker.markCurrent(); // Set mark to enable cleanup of old frames. + tracker.rewindToMark(); + } else { + // Reached end of the other side too. Advance marks on both trackers. 
+ tracker.markCurrent(); + otherTracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + }); + } + } else { + final int trackerToAdvance; + + if (markCmp < 0) { + trackerToAdvance = LEFT; + } else if (markCmp > 0) { + trackerToAdvance = RIGHT; + } else { + // Key is null on both sides. Note that there is a preference for running through the left side first + // on a FULL join. It doesn't really matter which side we run through first, but we do need to be consistent + // for the benefit of the logic in "shouldEmitColumnValue". + trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT; + } + + final Tracker tracker = trackers.get(trackerToAdvance); + + tracker.advance(); + + // On next iteration (when we're sure to have data), update mark if the key changed. + onNextIteration(() -> { + if (!tracker.isCurrentSameKeyAsMark()) { + tracker.markCurrent(); + trackerWithCompleteSetForCurrentKey = -1; + } + }); + } + } + + if (trackers.get(LEFT).isAtEnd() && trackers.get(RIGHT).isAtEnd()) { + // Both channels completely done. + flushCurrentFrame(); + return ReturnOrAwait.returnObject(0L); + } else { + // Keep reading. + return nextAwait(); + } + } + + @Override + public void cleanup() throws IOException + { + FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter, () -> trackers.forEach(Tracker::clear)); + } + + /** + * Returns a {@link ReturnOrAwait#awaitAll} for the channel numbers that need more data and have not yet hit their + * buffered-bytes limit, {@link Limits#MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN}. + * + * If all channels have hit their limit, throws {@link MSQException} with {@link TooManyRowsWithSameKeyFault}. 
+ */ + private ReturnOrAwait<Long> nextAwait() + { + final IntSet awaitSet = new IntOpenHashSet(); + int trackerAtLimit = -1; + + for (int i = 0; i < inputChannels.size(); i++) { + final Tracker tracker = trackers.get(i); + if (tracker.needsMoreData()) { + if (tracker.totalBytesBuffered() < Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + awaitSet.add(i); + } else if (trackerAtLimit < 0) { + trackerAtLimit = i; + } + } + } + + if (awaitSet.isEmpty() && trackerAtLimit > 0) { + // All trackers that need more data are at their max buffered bytes limit. Generate a nice exception. + final Tracker tracker = trackers.get(trackerAtLimit); + if (tracker.totalBytesBuffered() > Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) { + // Generate a nice exception. + throw new MSQException( + new TooManyRowsWithSameKeyFault( + tracker.readMarkKey(), + tracker.totalBytesBuffered(), + Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN + ) + ); + } + } + + return ReturnOrAwait.awaitAll(awaitSet); + } + + /** + * Compares the marked rows of the two {@link #trackers}. + * + * @throws IllegalStateException if either tracker does not have a marked row and is not completely done + */ + private int compareMarks() + { + final Tracker leftTracker = trackers.get(LEFT); + final Tracker rightTracker = trackers.get(RIGHT); + + Preconditions.checkState(leftTracker.hasMark() || leftTracker.isAtEnd(), "left.hasMark || left.isAtEnd"); + Preconditions.checkState(rightTracker.hasMark() || rightTracker.isAtEnd(), "right.hasMark || right.isAtEnd"); + + if (leftTracker.markFrame < 0) { + return rightTracker.markFrame < 0 ? 
0 : 1; + } else if (rightTracker.markFrame < 0) { + return -1; + } else { + final FrameHolder leftHolder = leftTracker.holders.get(leftTracker.markFrame); + final FrameHolder rightHolder = rightTracker.holders.get(rightTracker.markFrame); + return leftHolder.comparisonWidget.compare( + leftTracker.markRow, + rightHolder.comparisonWidget, + rightTracker.markRow + ); + } + } + + /** + * Pushes a frame from the indicated channel into the appropriate tracker. Returns true if a frame was pushed + * or if the channel is finished. + */ + private boolean pushNextFrame(final int channelNumber) + { + final ReadableFrameChannel channel = inputChannels.get(channelNumber); + final Tracker tracker = trackers.get(channelNumber); + + if (!channel.isFinished() && !channel.canRead()) { + return false; + } else if (channel.isFinished()) { + tracker.push(null); + return true; + } else { + final Frame frame = channel.read(); + + if (frame.numRows() == 0) { + // Skip, read next. + return false; + } else { + tracker.push(frame); + return true; + } + } + } + + private void onNextIteration(final Runnable runnable) + { + if (nextIterationRunnable != null) { + throw new ISE("postAdvanceRunnable already set"); + } else { + nextIterationRunnable = runnable; + } + } + + private void startNewFrameIfNeeded() + { + if (frameWriter == null) { + frameWriter = frameWriterFactory.newFrameWriter(joinColumnSelectorFactory); + } + } + + private void flushCurrentFrame() throws IOException + { + if (frameWriter != null) { + if (frameWriter.getNumRows() > 0) { + final Frame frame = Frame.wrap(frameWriter.toByteArray()); + frameWriter.close(); + frameWriter = null; + outputChannel.write(new FrameWithPartition(frame, FrameWithPartition.NO_PARTITION)); + } + } + } + + /** + * Tracks the current set of rows that have the same key from a sequence of frames. + * + * markFrame and markRow are set when we encounter a new key, which enables rewinding and re-reading data with the + * same key. 
+ */ + private class Tracker + { + private final List<FrameHolder> holders = new ArrayList<>(); + private final int channelNumber; + + // markFrame and markRow are the first frame and row with the current key. + private int markFrame = -1; + private int markRow = -1; + + // currentFrame is the frame containing the current cursor row. + private int currentFrame = -1; + + // done indicates that no more data is available in the channel. + private boolean done; + + public Tracker(int channelNumber) + { + this.channelNumber = channelNumber; + } + + /** + * Adds a holder for a frame. If this is the first frame, sets the current cursor position and mark to the first + * row of the frame. Otherwise, the cursor position and mark are not changed. + * + * Pushing a null frame indicates no more frames are coming. + * + * @param frame frame, or null indicating no more frames are coming + */ + public void push(final Frame frame) + { + if (frame == null) { + done = true; + return; + } + + if (done) { + throw new ISE("Cannot push frames when already done"); + } + + final boolean atEndOfPushedData = isAtEndOfPushedData(); + final FrameReader frameReader = inputs.get(channelNumber).getChannelFrameReader(); + final FrameCursor cursor = FrameProcessors.makeCursor(frame, frameReader); + final FrameComparisonWidget comparisonWidget = + frameReader.makeComparisonWidget(frame, keyColumns.get(channelNumber)); + + final RowSignature.Builder keySignatureBuilder = RowSignature.builder(); + for (final KeyColumn keyColumn : keyColumns.get(channelNumber)) { + keySignatureBuilder.add( + keyColumn.columnName(), + frameReader.signature().getColumnType(keyColumn.columnName()).orElse(null) + ); + } + + holders.add( + new FrameHolder( + frame, + RowKeyReader.create(keySignatureBuilder.build()), + cursor, + comparisonWidget + ) + ); + + if (atEndOfPushedData) { + // Move currentFrame so it points at the next row, which we now have, instead of an "isDone" cursor. + currentFrame = currentFrame < 0 ? 
0 : currentFrame + 1; + } + + if (markFrame < 0) { + // Cleared mark means we want the current row to be marked. + markFrame = currentFrame; + markRow = 0; + } + } + + /** + * Number of bytes currently buffered in {@link #holders}. + */ + public long totalBytesBuffered() + { + long bytes = 0; + for (final FrameHolder holder : holders) { + bytes += holder.frame.numBytes(); + } + return bytes; + } + + /** + * Cursor containing the current row. + */ + @Nullable + public FrameCursor currentCursor() + { + if (currentFrame < 0) { + return null; + } else { + return holders.get(currentFrame).cursor; + } + } + + /** + * Advances the current row (the current row of {@link #currentFrame}). After calling this method, + * {@link #isAtEndOfPushedData()} may start returning true. + */ + public void advance() + { + assert !isAtEndOfPushedData(); + + final FrameHolder currentHolder = holders.get(currentFrame); + + currentHolder.cursor.advance(); + + if (currentHolder.cursor.isDone() && currentFrame + 1 < holders.size()) { + currentFrame++; + holders.get(currentFrame).cursor.reset(); + } + } + + /** + * Whether this tracker has a marked row. + */ + public boolean hasMark() + { + return markFrame >= 0; + } + + /** + * Whether this tracker has a marked row that is partially null. + */ + public boolean hasPartiallyNullMark() + { + return hasMark() && holders.get(markFrame).comparisonWidget.isPartiallyNullKey(markRow); + } + + /** + * Reads the current marked key. + */ + @Nullable + public List<Object> readMarkKey() + { + if (!hasMark()) { + return null; + } + + final FrameHolder markHolder = holders.get(markFrame); + final RowKey markKey = markHolder.comparisonWidget.readKey(markRow); + return markHolder.keyReader.read(markKey); + } + + /** + * Rewind to the mark row: the first one with the current key. 
+ * + * @throws IllegalStateException if there is no marked row + */ + public void rewindToMark() + { + if (markFrame < 0) { + throw new ISE("No mark"); + } + + currentFrame = markFrame; + holders.get(currentFrame).cursor.setCurrentRow(markRow); + } + + /** + * Set the mark row to the current row. Used when data from the old mark to the current row is no longer needed. + */ + public void markCurrent() + { + if (isAtEndOfPushedData()) { + clear(); + } else { + // Remove unnecessary holders. + while (currentFrame > 0) { + if (currentFrame == holders.size() - 1) { + final FrameHolder lastHolder = holders.get(currentFrame); + holders.clear(); + holders.add(lastHolder); + currentFrame = 0; + } else { + holders.remove(0); + currentFrame--; + } + } + + markFrame = 0; + markRow = holders.get(currentFrame).cursor.getCurrentRow(); + } + } + + /** + * Whether the current cursor is past the end of the last frame for which we have data. + */ + public boolean isAtEndOfPushedData() + { + return currentFrame < 0 || (currentFrame == holders.size() - 1 && holders.get(currentFrame).cursor.isDone()); + } + + /** + * Whether the current cursor is past the end of all data that will ever be pushed. + */ + public boolean isAtEnd() + { + return done && isAtEndOfPushedData(); + } + + /** + * Whether this tracker needs more data in order to read the current cursor location or move it forward. + */ + public boolean needsMoreData() + { + return !done && isAtEndOfPushedData(); + } + + /** + * Whether this tracker contains all rows for the marked key. 
+ * + * @throws IllegalStateException if there is no marked key + */ + public boolean hasCompleteSetForMark() + { + if (markFrame < 0) { + throw new ISE("No mark"); + } + + if (done) { + return true; + } + + final FrameHolder lastHolder = holders.get(holders.size() - 1); + return !isSameKeyAsMark(lastHolder, lastHolder.frame.numRows() - 1); + } + + /** + * Whether the current position (the current row of the {@link #currentFrame}) compares equally to the mark row. + * If {@link #isAtEnd()}, returns true iff there is no mark row. + */ + public boolean isCurrentSameKeyAsMark() + { + if (isAtEnd()) { + return markFrame < 0; + } else { + assert !isAtEndOfPushedData(); + final FrameHolder headHolder = holders.get(currentFrame); + return isSameKeyAsMark(headHolder, headHolder.cursor.getCurrentRow()); + } + } + + /** + * Clears the current mark and all buffered frames. Does not change {@link #done}. + */ + public void clear() + { + holders.clear(); + markFrame = -1; + markRow = -1; + currentFrame = -1; + } + + /** + * Whether the provided frame and row compares equally to the mark row. The provided row must be at, or after, + * the mark row. + */ + private boolean isSameKeyAsMark(final FrameHolder holder, final int row) + { + // Mark row must exist. + assert markFrame >= 0; + + // Row must exist. + assert row >= 0 && row < holder.frame.numRows(); + + final FrameHolder markHolder = holders.get(markFrame); + final int cmp = markHolder.comparisonWidget.compare(markRow, holder.comparisonWidget, row); + + assert cmp <= 0; Review Comment: How important is this assert? If it's truly important, perhaps it should be an explicit if statement instead (branch prediction *should* handle that very well)? ########## extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java: ########## @@ -0,0 +1,987 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.querykit.common; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.frame.Frame; +import org.apache.druid.frame.FrameType; +import org.apache.druid.frame.allocation.ArenaMemoryAllocator; +import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory; +import org.apache.druid.frame.channel.BlockingQueueFrameChannel; +import org.apache.druid.frame.channel.ReadableFrameChannel; +import org.apache.druid.frame.channel.ReadableNilFrameChannel; +import org.apache.druid.frame.key.KeyColumn; +import org.apache.druid.frame.key.KeyOrder; +import org.apache.druid.frame.processor.FrameProcessorExecutor; +import org.apache.druid.frame.read.FrameReader; +import org.apache.druid.frame.segment.FrameStorageAdapter; +import org.apache.druid.frame.testutil.FrameSequenceBuilder; +import org.apache.druid.frame.testutil.FrameTestUtil; +import org.apache.druid.frame.write.FrameWriterFactory; +import 
org.apache.druid.frame.write.FrameWriters; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; +import org.apache.druid.msq.input.ReadableInput; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.kernel.StagePartition; +import org.apache.druid.msq.test.LimitedFrameWriterFactory; +import org.apache.druid.segment.RowBasedSegment; +import org.apache.druid.segment.StorageAdapter; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.join.JoinTestHelper; +import org.apache.druid.segment.join.JoinType; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.apache.druid.timeline.SegmentId; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +@RunWith(Parameterized.class) +public class SortMergeJoinFrameProcessorTest extends InitializedNullHandlingTest +{ + private static final StagePartition STAGE_PARTITION = new StagePartition(new StageId("q", 0), 0); + + private final int rowsPerInputFrame; + private final int rowsPerOutputFrame; + + private FrameProcessorExecutor exec; + + @Rule + public TemporaryFolder temporaryFolder = new TemporaryFolder(); + + public SortMergeJoinFrameProcessorTest(int rowsPerInputFrame, int rowsPerOutputFrame) + { + this.rowsPerInputFrame = rowsPerInputFrame; + this.rowsPerOutputFrame = rowsPerOutputFrame; + } + + @Parameterized.Parameters(name = 
"rowsPerInputFrame = {0}, rowsPerOutputFrame = {1}") + public static Iterable<Object[]> constructorFeeder() + { + final List<Object[]> constructors = new ArrayList<>(); + + for (final int rowsPerInputFrame : new int[]{1, 2, 7, Integer.MAX_VALUE}) { + for (final int rowsPerOutputFrame : new int[]{1, 2, 7, Integer.MAX_VALUE}) { + constructors.add(new Object[]{rowsPerInputFrame, rowsPerOutputFrame}); + } + } + + return constructors; + } + + @Before + public void setUp() + { + exec = new FrameProcessorExecutor(MoreExecutors.listeningDecorator(Execs.singleThreaded("test-exec"))); + } + + @After + public void tearDown() throws Exception + { + exec.getExecutorService().shutdownNow(); + exec.getExecutorService().awaitTermination(10, TimeUnit.MINUTES); + } + + @Test + public void testLeftJoinEmptyLeftSide() throws Exception + { + final ReadableInput factChannel = ReadableInput.channel( + ReadableNilFrameChannel.INSTANCE, + FrameReader.create(JoinTestHelper.FACT_SIGNATURE), + STAGE_PARTITION + ); + + final ReadableInput countriesChannel = + buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("page", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .add("j0.countryName", ColumnType.STRING) + .add("j0.countryNumber", ColumnType.LONG) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + factChannel, + countriesChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) + ), + JoinType.LEFT + ); + + assertResult(processor, outputChannel.readable(), joinSignature, Collections.emptyList()); 
+ } + + @Test + public void testLeftJoinEmptyRightSide() throws Exception + { + final ReadableInput factChannel = buildFactInput( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + + final ReadableInput countriesChannel = ReadableInput.channel( + ReadableNilFrameChannel.INSTANCE, + FrameReader.create(JoinTestHelper.COUNTRIES_SIGNATURE), + STAGE_PARTITION + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("page", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .add("j0.countryName", ColumnType.STRING) + .add("j0.countryNumber", ColumnType.LONG) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + factChannel, + countriesChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) + ), + JoinType.LEFT + ); + + final List<List<Object>> expectedRows = Arrays.asList( + Arrays.asList("Agama mossambica", null, null, null, null), + Arrays.asList("Apamea abruzzorum", null, null, null, null), + Arrays.asList("Atractus flammigerus", null, null, null, null), + Arrays.asList("Rallicula", null, null, null, null), + Arrays.asList("Talk:Oswald Tilghman", null, null, null, null), + Arrays.asList("Peremptory norm", "AU", null, null, null), + Arrays.asList("Didier Leclair", "CA", null, null, null), + Arrays.asList("Les Argonautes", "CA", null, null, null), + Arrays.asList("Sarah Michelle Gellar", "CA", null, null, null), + Arrays.asList("Golpe de Estado en Chile de 1973", "CL", null, null, null), + Arrays.asList("Diskussion:Sebastian Schulz", "DE", null, null, null), + Arrays.asList("Gabinete 
Ministerial de Rafael Correa", "EC", null, null, null), + Arrays.asList("Saison 9 de Secret Story", "FR", null, null, null), + Arrays.asList("Glasgow", "GB", null, null, null), + Arrays.asList("Giusy Ferreri discography", "IT", null, null, null), + Arrays.asList("Roma-Bangkok", "IT", null, null, null), + Arrays.asList("青野武", "JP", null, null, null), + Arrays.asList("유희왕 GX", "KR", null, null, null), + Arrays.asList("History of Fourems", "MMMM", null, null, null), + Arrays.asList("Mathis Bolly", "MX", null, null, null), + Arrays.asList("Orange Soda", "MatchNothing", null, null, null), + Arrays.asList("Алиса в Зазеркалье", "NO", null, null, null), + Arrays.asList("Cream Soda", "SU", null, null, null), + Arrays.asList("Wendigo", "SV", null, null, null), + Arrays.asList("Carlo Curti", "US", null, null, null), + Arrays.asList("DirecTV", "US", null, null, null), + Arrays.asList("Old Anatolian Turkish", "US", null, null, null), + Arrays.asList("Otjiwarongo Airport", "US", null, null, null), + Arrays.asList("President of India", "US", null, null, null) + ); + + assertResult(processor, outputChannel.readable(), joinSignature, expectedRows); + } + + @Test + public void testInnerJoinEmptyRightSide() throws Exception + { + final ReadableInput factChannel = buildFactInput( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput countriesChannel = ReadableInput.channel( + ReadableNilFrameChannel.INSTANCE, + FrameReader.create(JoinTestHelper.COUNTRIES_SIGNATURE), + STAGE_PARTITION + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("page", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .add("j0.countryName", ColumnType.STRING) + .add("j0.countryNumber", ColumnType.LONG) + .build(); + + final 
SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + factChannel, + countriesChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) + ), + JoinType.INNER + ); + + assertResult(processor, outputChannel.readable(), joinSignature, Collections.emptyList()); + } + + @Test + public void testLeftJoinCountryIsoCode() throws Exception + { + final ReadableInput factChannel = buildFactInput( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput countriesChannel = + buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING))); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("page", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .add("j0.countryName", ColumnType.STRING) + .add("j0.countryNumber", ColumnType.LONG) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + factChannel, + countriesChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)) + ), + JoinType.LEFT + ); + + final List<List<Object>> expectedRows = Arrays.asList( + Arrays.asList("Agama mossambica", null, null, null, null), + Arrays.asList("Apamea abruzzorum", null, null, null, null), + Arrays.asList("Atractus flammigerus", null, null, null, null), + Arrays.asList("Rallicula", null, null, null, null), + Arrays.asList("Talk:Oswald Tilghman", null, null, null, null), 
+ Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L), + Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L), + Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L), + Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L), + Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 2L), + Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 3L), + Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", "Ecuador", 4L), + Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L), + Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L), + Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L), + Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L), + Arrays.asList("青野武", "JP", "JP", "Japan", 8L), + Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L), + Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L), + Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L), + Arrays.asList("Orange Soda", "MatchNothing", null, null, null), + Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L), + Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L), + Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L), + Arrays.asList("Carlo Curti", "US", "US", "United States", 13L), + Arrays.asList("DirecTV", "US", "US", "United States", 13L), + Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 13L), + Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L), + Arrays.asList("President of India", "US", "US", "United States", 13L) + ); + + assertResult(processor, outputChannel.readable(), joinSignature, expectedRows); + } + + @Test + public void testCrossJoin() throws Exception + { + final ReadableInput factChannel = buildFactInput( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput countriesChannel = 
makeChannelFromResourceWithLimit( + JoinTestHelper.COUNTRIES_RESOURCE, + JoinTestHelper.COUNTRIES_SIGNATURE, + ImmutableList.of(new KeyColumn("countryIsoCode", KeyOrder.ASCENDING)), + 2 + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("j0.page", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + countriesChannel, + factChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of(Collections.emptyList(), Collections.emptyList()), + JoinType.INNER + ); + + final List<List<Object>> expectedRows = Arrays.asList( + Arrays.asList("Agama mossambica", "AU"), + Arrays.asList("Agama mossambica", "CA"), + Arrays.asList("Apamea abruzzorum", "AU"), + Arrays.asList("Apamea abruzzorum", "CA"), + Arrays.asList("Atractus flammigerus", "AU"), + Arrays.asList("Atractus flammigerus", "CA"), + Arrays.asList("Rallicula", "AU"), + Arrays.asList("Rallicula", "CA"), + Arrays.asList("Talk:Oswald Tilghman", "AU"), + Arrays.asList("Talk:Oswald Tilghman", "CA"), + Arrays.asList("Peremptory norm", "AU"), + Arrays.asList("Peremptory norm", "CA"), + Arrays.asList("Didier Leclair", "AU"), + Arrays.asList("Didier Leclair", "CA"), + Arrays.asList("Les Argonautes", "AU"), + Arrays.asList("Les Argonautes", "CA"), + Arrays.asList("Sarah Michelle Gellar", "AU"), + Arrays.asList("Sarah Michelle Gellar", "CA"), + Arrays.asList("Golpe de Estado en Chile de 1973", "AU"), + Arrays.asList("Golpe de Estado en Chile de 1973", "CA"), + Arrays.asList("Diskussion:Sebastian Schulz", "AU"), + Arrays.asList("Diskussion:Sebastian Schulz", "CA"), + Arrays.asList("Gabinete Ministerial de Rafael Correa", "AU"), + Arrays.asList("Gabinete Ministerial de Rafael Correa", "CA"), + Arrays.asList("Saison 9 de Secret Story", "AU"), + Arrays.asList("Saison 9 de 
Secret Story", "CA"), + Arrays.asList("Glasgow", "AU"), + Arrays.asList("Glasgow", "CA"), + Arrays.asList("Giusy Ferreri discography", "AU"), + Arrays.asList("Giusy Ferreri discography", "CA"), + Arrays.asList("Roma-Bangkok", "AU"), + Arrays.asList("Roma-Bangkok", "CA"), + Arrays.asList("青野武", "AU"), + Arrays.asList("青野武", "CA"), + Arrays.asList("유희왕 GX", "AU"), + Arrays.asList("유희왕 GX", "CA"), + Arrays.asList("History of Fourems", "AU"), + Arrays.asList("History of Fourems", "CA"), + Arrays.asList("Mathis Bolly", "AU"), + Arrays.asList("Mathis Bolly", "CA"), + Arrays.asList("Orange Soda", "AU"), + Arrays.asList("Orange Soda", "CA"), + Arrays.asList("Алиса в Зазеркалье", "AU"), + Arrays.asList("Алиса в Зазеркалье", "CA"), + Arrays.asList("Cream Soda", "AU"), + Arrays.asList("Cream Soda", "CA"), + Arrays.asList("Wendigo", "AU"), + Arrays.asList("Wendigo", "CA"), + Arrays.asList("Carlo Curti", "AU"), + Arrays.asList("Carlo Curti", "CA"), + Arrays.asList("DirecTV", "AU"), + Arrays.asList("DirecTV", "CA"), + Arrays.asList("Old Anatolian Turkish", "AU"), + Arrays.asList("Old Anatolian Turkish", "CA"), + Arrays.asList("Otjiwarongo Airport", "AU"), + Arrays.asList("Otjiwarongo Airport", "CA"), + Arrays.asList("President of India", "AU"), + Arrays.asList("President of India", "CA") + ); + + assertResult(processor, outputChannel.readable(), joinSignature, expectedRows); + } + + @Test + public void testLeftJoinRegions() throws Exception + { + final ReadableInput factChannel = + buildFactInput( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput regionsChannel = + buildRegionsInput( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING) + ) + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final 
RowSignature joinSignature = + RowSignature.builder() + .add("page", ColumnType.STRING) + .add("j0.regionName", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + factChannel, + regionsChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING) + ), + ImmutableList.of( + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING), + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING) + ) + ), + JoinType.LEFT + ); + + final List<List<Object>> expectedRows = Arrays.asList( + Arrays.asList("Agama mossambica", null, null), + Arrays.asList("Apamea abruzzorum", null, null), + Arrays.asList("Atractus flammigerus", null, null), + Arrays.asList("Rallicula", null, null), + Arrays.asList("Talk:Oswald Tilghman", null, null), + Arrays.asList("Peremptory norm", "New South Wales", "AU"), + Arrays.asList("Didier Leclair", "Ontario", "CA"), + Arrays.asList("Sarah Michelle Gellar", "Ontario", "CA"), + Arrays.asList("Les Argonautes", "Quebec", "CA"), + Arrays.asList("Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "CL"), + Arrays.asList("Diskussion:Sebastian Schulz", "Hesse", "DE"), + Arrays.asList("Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "EC"), + Arrays.asList("Saison 9 de Secret Story", "Val d'Oise", "FR"), + Arrays.asList("Glasgow", "Kingston upon Hull", "GB"), + Arrays.asList("Giusy Ferreri discography", "Provincia di Varese", "IT"), + Arrays.asList("Roma-Bangkok", "Provincia di Varese", "IT"), + Arrays.asList("青野武", "Tōkyō", "JP"), + Arrays.asList("유희왕 GX", "Seoul", "KR"), + Arrays.asList("History of Fourems", "Fourems Province", "MMMM"), + Arrays.asList("Mathis Bolly", "Mexico City", "MX"), + Arrays.asList("Orange Soda", null, "MatchNothing"), + Arrays.asList("Алиса в 
Зазеркалье", "Finnmark Fylke", "NO"), + Arrays.asList("Cream Soda", "Ainigriv", "SU"), + Arrays.asList("Wendigo", "Departamento de San Salvador", "SV"), + Arrays.asList("Carlo Curti", "California", "US"), + Arrays.asList("Otjiwarongo Airport", "California", "US"), + Arrays.asList("President of India", "California", "US"), + Arrays.asList("DirecTV", "North Carolina", "US"), + Arrays.asList("Old Anatolian Turkish", "Virginia", "US") + ); + + assertResult(processor, outputChannel.readable(), joinSignature, expectedRows); + } + + @Test + public void testRightJoinRegionCodeOnly() throws Exception + { + // This join generates duplicates. + + final ReadableInput factChannel = + buildFactInput( + ImmutableList.of( + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput regionsChannel = + buildRegionsInput( + ImmutableList.of( + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING), + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING) + ) + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("j0.page", ColumnType.STRING) + .add("regionName", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + regionsChannel, + factChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)) + ), + JoinType.RIGHT + ); + + final List<List<Object>> expectedRows = Arrays.asList( + Arrays.asList("Agama mossambica", null, null), + Arrays.asList("Apamea abruzzorum", null, null), + Arrays.asList("Atractus flammigerus", null, null), + Arrays.asList("Rallicula", null, null), + Arrays.asList("Talk:Oswald Tilghman", null, 
null), + Arrays.asList("유희왕 GX", "Seoul", "KR"), + Arrays.asList("青野武", "Tōkyō", "JP"), + Arrays.asList("Алиса в Зазеркалье", "Finnmark Fylke", "NO"), + Arrays.asList("Saison 9 de Secret Story", "Val d'Oise", "FR"), + Arrays.asList("Cream Soda", "Ainigriv", "SU"), + Arrays.asList("Carlo Curti", "California", "US"), + Arrays.asList("Otjiwarongo Airport", "California", "US"), + Arrays.asList("President of India", "California", "US"), + Arrays.asList("Mathis Bolly", "Mexico City", "MX"), + Arrays.asList("Gabinete Ministerial de Rafael Correa", "Provincia del Guayas", "EC"), + Arrays.asList("Diskussion:Sebastian Schulz", "Hesse", "DE"), + Arrays.asList("Glasgow", "Kingston upon Hull", "GB"), + Arrays.asList("History of Fourems", "Fourems Province", "MMMM"), + Arrays.asList("Orange Soda", null, "MatchNothing"), + Arrays.asList("DirecTV", "North Carolina", "US"), + Arrays.asList("Peremptory norm", "New South Wales", "AU"), + Arrays.asList("Didier Leclair", "Ontario", "CA"), + Arrays.asList("Sarah Michelle Gellar", "Ontario", "CA"), + Arrays.asList("Les Argonautes", "Quebec", "CA"), + Arrays.asList("Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "CL"), + Arrays.asList("Wendigo", "Departamento de San Salvador", "SV"), + Arrays.asList("Giusy Ferreri discography", "Provincia di Varese", "IT"), + Arrays.asList("Giusy Ferreri discography", "Virginia", "IT"), + Arrays.asList("Old Anatolian Turkish", "Provincia di Varese", "US"), + Arrays.asList("Old Anatolian Turkish", "Virginia", "US"), + Arrays.asList("Roma-Bangkok", "Provincia di Varese", "IT"), + Arrays.asList("Roma-Bangkok", "Virginia", "IT") + ); + + assertResult(processor, outputChannel.readable(), joinSignature, expectedRows); + } + + @Test + public void testFullOuterJoinRegionCodeOnly() throws Exception + { + // This join generates duplicates. 
+ + final ReadableInput factChannel = + buildFactInput( + ImmutableList.of( + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput regionsChannel = + buildRegionsInput( + ImmutableList.of( + new KeyColumn("regionIsoCode", KeyOrder.ASCENDING), + new KeyColumn("countryIsoCode", KeyOrder.ASCENDING) + ) + ); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("j0.page", ColumnType.STRING) + .add("regionName", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .build(); + + final SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + regionsChannel, + factChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)) + ), + JoinType.FULL + ); + + final List<List<Object>> expectedRows = Arrays.asList( + Arrays.asList(null, "Nulland", null), + Arrays.asList("Agama mossambica", null, null), + Arrays.asList("Apamea abruzzorum", null, null), + Arrays.asList("Atractus flammigerus", null, null), + Arrays.asList("Rallicula", null, null), + Arrays.asList("Talk:Oswald Tilghman", null, null), + Arrays.asList("유희왕 GX", "Seoul", "KR"), + Arrays.asList("青野武", "Tōkyō", "JP"), + Arrays.asList("Алиса в Зазеркалье", "Finnmark Fylke", "NO"), + Arrays.asList("Saison 9 de Secret Story", "Val d'Oise", "FR"), + Arrays.asList(null, "Foureis Province", null), + Arrays.asList("Cream Soda", "Ainigriv", "SU"), + Arrays.asList("Carlo Curti", "California", "US"), + Arrays.asList("Otjiwarongo Airport", "California", "US"), + Arrays.asList("President of India", "California", "US"), + Arrays.asList("Mathis Bolly", "Mexico City", "MX"), + Arrays.asList("Gabinete Ministerial de Rafael Correa", "Provincia 
del Guayas", "EC"), + Arrays.asList("Diskussion:Sebastian Schulz", "Hesse", "DE"), + Arrays.asList("Glasgow", "Kingston upon Hull", "GB"), + Arrays.asList("History of Fourems", "Fourems Province", "MMMM"), + Arrays.asList("Orange Soda", null, "MatchNothing"), + Arrays.asList("DirecTV", "North Carolina", "US"), + Arrays.asList("Peremptory norm", "New South Wales", "AU"), + Arrays.asList("Didier Leclair", "Ontario", "CA"), + Arrays.asList("Sarah Michelle Gellar", "Ontario", "CA"), + Arrays.asList("Les Argonautes", "Quebec", "CA"), + Arrays.asList("Golpe de Estado en Chile de 1973", "Santiago Metropolitan", "CL"), + Arrays.asList("Wendigo", "Departamento de San Salvador", "SV"), + Arrays.asList("Giusy Ferreri discography", "Provincia di Varese", "IT"), + Arrays.asList("Giusy Ferreri discography", "Virginia", "IT"), + Arrays.asList("Old Anatolian Turkish", "Provincia di Varese", "US"), + Arrays.asList("Old Anatolian Turkish", "Virginia", "US"), + Arrays.asList("Roma-Bangkok", "Provincia di Varese", "IT"), + Arrays.asList("Roma-Bangkok", "Virginia", "IT"), + Arrays.asList(null, "Usca City", null) + ); + + assertResult(processor, outputChannel.readable(), joinSignature, expectedRows); + } + + @Test + public void testLeftJoinCountryNumber() throws Exception + { + final ReadableInput factChannel = buildFactInput( + ImmutableList.of( + new KeyColumn("countryNumber", KeyOrder.ASCENDING), + new KeyColumn("page", KeyOrder.ASCENDING) + ) + ); + + final ReadableInput countriesChannel = + buildCountriesInput(ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING))); + + final BlockingQueueFrameChannel outputChannel = BlockingQueueFrameChannel.minimal(); + + final RowSignature joinSignature = + RowSignature.builder() + .add("page", ColumnType.STRING) + .add("countryIsoCode", ColumnType.STRING) + .add("j0.countryIsoCode", ColumnType.STRING) + .add("j0.countryName", ColumnType.STRING) + .add("j0.countryNumber", ColumnType.LONG) + .build(); + + final 
SortMergeJoinFrameProcessor processor = new SortMergeJoinFrameProcessor( + factChannel, + countriesChannel, + outputChannel.writable(), + makeFrameWriterFactory(joinSignature), + "j0.", + ImmutableList.of( + ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)), + ImmutableList.of(new KeyColumn("countryNumber", KeyOrder.ASCENDING)) + ), + JoinType.LEFT + ); + + final String countryCodeForNull; + final String countryNameForNull; + final Long countryNumberForNull; + + if (NullHandling.sqlCompatible()) { + countryCodeForNull = null; + countryNameForNull = null; + countryNumberForNull = null; + } else { + // In default-value mode, null country number from the left-hand table converts to zero, which matches Australia. + countryCodeForNull = "AU"; + countryNameForNull = "Australia"; + countryNumberForNull = 0L; Review Comment: I'm like 80% sure this assumption would be incorrect if the long column actually came from the nested column instead of the "normal" top-level long column. That is, it's perhaps less a function of the "default-value mode" and more a function of the column implementation that happens to be getting used. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
