This is an automated email from the ASF dual-hosted git repository.

lakshsingla pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 58f3faf299 SortMergeJoinFrameProcessor: Fix two bugs with buffering. 
(#14196)
58f3faf299 is described below

commit 58f3faf2996051a037555d665da6a8781215e037
Author: Gian Merlino <[email protected]>
AuthorDate: Sun Jul 2 07:22:52 2023 -0700

    SortMergeJoinFrameProcessor: Fix two bugs with buffering. (#14196)
    
    1) Fix a problem where the fault wasn't reported when the left-hand side
       had too many buffered frames. (Instead, frames continued to be buffered,
       eventually running the server out of memory.)
    
    2) Always update the mark when rewinding isn't necessary. This fixes a
       problem where frames would be needlessly buffered when there isn't a
       key match across the two sides.
    
    3) Memory reserved for building the trackers now changes based on the heap
       size.
---
 .../java/org/apache/druid/msq/exec/Limits.java     |   6 -
 .../druid/msq/exec/WorkerMemoryParameters.java     |  51 ++--
 .../error/TooManyRowsWithSameKeyFault.java         |   8 +-
 .../druid/msq/querykit/BroadcastJoinHelper.java    |   3 +-
 .../common/SortMergeJoinFrameProcessor.java        | 318 ++++++++++++++-------
 .../common/SortMergeJoinFrameProcessorFactory.java |   3 +-
 .../druid/msq/exec/WorkerMemoryParametersTest.java |  50 ++--
 .../common/SortMergeJoinFrameProcessorTest.java    | 287 ++++++++++++++++++-
 8 files changed, 550 insertions(+), 176 deletions(-)

diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
index c9e598c618..9069794222 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Limits.java
@@ -69,12 +69,6 @@ public class Limits
    */
   public static final int MAX_KERNEL_MANIPULATION_QUEUE_SIZE = 100_000;
 
-  /**
-   * Maximum number of bytes buffered for each side of a
-   * {@link org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor}, 
not counting the most recent frame read.
-   */
-  public static final int MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN = 10_000_000;
-
   /**
    * Maximum relaunches across all workers.
    */
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
index f64e4dbd0e..4bddb949f0 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerMemoryParameters.java
@@ -137,29 +137,33 @@ public class WorkerMemoryParameters
    * we use a value somewhat lower than 0.5.
    */
   static final double BROADCAST_JOIN_MEMORY_FRACTION = 0.3;
+
+  /**
+   * Fraction of free memory per bundle that can be used by
+   * {@link org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor} 
to buffer frames in its trackers.
+   */
+  static final double SORT_MERGE_JOIN_MEMORY_FRACTION = 0.9;
+
   /**
    * In case {@link NotEnoughMemoryFault} is thrown, a fixed estimation 
overhead is added when estimating total memory required for the process.
    */
   private static final long BUFFER_BYTES_FOR_ESTIMATION = 1000;
 
+  private final long processorBundleMemory;
   private final int superSorterMaxActiveProcessors;
   private final int superSorterMaxChannelsPerProcessor;
-  private final long appenderatorMemory;
-  private final long broadcastJoinMemory;
   private final int partitionStatisticsMaxRetainedBytes;
 
   WorkerMemoryParameters(
+      final long processorBundleMemory,
       final int superSorterMaxActiveProcessors,
       final int superSorterMaxChannelsPerProcessor,
-      final long appenderatorMemory,
-      final long broadcastJoinMemory,
       final int partitionStatisticsMaxRetainedBytes
   )
   {
+    this.processorBundleMemory = processorBundleMemory;
     this.superSorterMaxActiveProcessors = superSorterMaxActiveProcessors;
     this.superSorterMaxChannelsPerProcessor = 
superSorterMaxChannelsPerProcessor;
-    this.appenderatorMemory = appenderatorMemory;
-    this.broadcastJoinMemory = broadcastJoinMemory;
     this.partitionStatisticsMaxRetainedBytes = 
partitionStatisticsMaxRetainedBytes;
   }
 
@@ -344,10 +348,9 @@ public class WorkerMemoryParameters
     }
 
     return new WorkerMemoryParameters(
+        bundleMemoryForProcessing,
         superSorterMaxActiveProcessors,
         superSorterMaxChannelsPerProcessor,
-        (long) (bundleMemoryForProcessing * APPENDERATOR_MEMORY_FRACTION),
-        (long) (bundleMemoryForProcessing * BROADCAST_JOIN_MEMORY_FRACTION),
         Ints.checkedCast(workerMemory) // 100% of worker memory is devoted to 
partition statistics
     );
   }
@@ -365,13 +368,13 @@ public class WorkerMemoryParameters
   public long getAppenderatorMaxBytesInMemory()
   {
     // Half for indexing, half for merging.
-    return Math.max(1, appenderatorMemory / 2);
+    return Math.max(1, getAppenderatorMemory() / 2);
   }
 
   public int getAppenderatorMaxColumnsToMerge()
   {
     // Half for indexing, half for merging.
-    return Ints.checkedCast(Math.max(2, appenderatorMemory / 2 / 
APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN));
+    return Ints.checkedCast(Math.max(2, getAppenderatorMemory() / 2 / 
APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN));
   }
 
   public int getStandardFrameSize()
@@ -386,7 +389,12 @@ public class WorkerMemoryParameters
 
   public long getBroadcastJoinMemory()
   {
-    return broadcastJoinMemory;
+    return (long) (processorBundleMemory * BROADCAST_JOIN_MEMORY_FRACTION);
+  }
+
+  public long getSortMergeJoinMemory()
+  {
+    return (long) (processorBundleMemory * SORT_MERGE_JOIN_MEMORY_FRACTION);
   }
 
   public int getPartitionStatisticsMaxRetainedBytes()
@@ -394,6 +402,14 @@ public class WorkerMemoryParameters
     return partitionStatisticsMaxRetainedBytes;
   }
 
+  /**
+   * Amount of memory to devote to {@link 
org.apache.druid.segment.realtime.appenderator.Appenderator}.
+   */
+  private long getAppenderatorMemory()
+  {
+    return (long) (processorBundleMemory * APPENDERATOR_MEMORY_FRACTION);
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -404,10 +420,9 @@ public class WorkerMemoryParameters
       return false;
     }
     WorkerMemoryParameters that = (WorkerMemoryParameters) o;
-    return superSorterMaxActiveProcessors == 
that.superSorterMaxActiveProcessors
+    return processorBundleMemory == that.processorBundleMemory
+           && superSorterMaxActiveProcessors == 
that.superSorterMaxActiveProcessors
            && superSorterMaxChannelsPerProcessor == 
that.superSorterMaxChannelsPerProcessor
-           && appenderatorMemory == that.appenderatorMemory
-           && broadcastJoinMemory == that.broadcastJoinMemory
            && partitionStatisticsMaxRetainedBytes == 
that.partitionStatisticsMaxRetainedBytes;
   }
 
@@ -415,10 +430,9 @@ public class WorkerMemoryParameters
   public int hashCode()
   {
     return Objects.hash(
+        processorBundleMemory,
         superSorterMaxActiveProcessors,
         superSorterMaxChannelsPerProcessor,
-        appenderatorMemory,
-        broadcastJoinMemory,
         partitionStatisticsMaxRetainedBytes
     );
   }
@@ -427,10 +441,9 @@ public class WorkerMemoryParameters
   public String toString()
   {
     return "WorkerMemoryParameters{" +
-           "superSorterMaxActiveProcessors=" + superSorterMaxActiveProcessors +
+           "processorBundleMemory=" + processorBundleMemory +
+           ", superSorterMaxActiveProcessors=" + 
superSorterMaxActiveProcessors +
            ", superSorterMaxChannelsPerProcessor=" + 
superSorterMaxChannelsPerProcessor +
-           ", appenderatorMemory=" + appenderatorMemory +
-           ", broadcastJoinMemory=" + broadcastJoinMemory +
            ", partitionStatisticsMaxRetainedBytes=" + 
partitionStatisticsMaxRetainedBytes +
            '}';
   }
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
index 21fa363af8..60d355579b 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/TooManyRowsWithSameKeyFault.java
@@ -44,10 +44,12 @@ public class TooManyRowsWithSameKeyFault extends 
BaseMSQFault
   {
     super(
         CODE,
-        "Too many rows with the same key during sort-merge join (bytes 
buffered = %,d; limit = %,d). Key: %s",
+        "Too many rows with the same key[%s] during sort-merge join (bytes 
buffered[%,d], limit[%,d]). "
+        + "Try increasing heap memory available to workers, "
+        + "or adjusting your query to process fewer rows with this key.",
+        key,
         numBytes,
-        maxBytes,
-        key
+        maxBytes
     );
 
     this.key = key;
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java
index 36dc52c5ce..d9e7bc6dee 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/BroadcastJoinHelper.java
@@ -27,6 +27,7 @@ import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.channel.ReadableFrameChannel;
 import org.apache.druid.frame.processor.FrameProcessors;
 import org.apache.druid.frame.read.FrameReader;
+import org.apache.druid.msq.exec.WorkerMemoryParameters;
 import org.apache.druid.msq.indexing.error.BroadcastTablesTooLargeFault;
 import org.apache.druid.msq.indexing.error.MSQException;
 import org.apache.druid.query.DataSource;
@@ -58,7 +59,7 @@ public class BroadcastJoinHelper
    * @param channels                         list of input channels
    * @param channelReaders                   list of input channel readers; 
corresponds one-to-one with "channels"
    * @param memoryReservedForBroadcastJoin   total bytes of frames we are 
permitted to use; derived from
-   *                                         {@link 
org.apache.druid.msq.exec.WorkerMemoryParameters#broadcastJoinMemory}
+   *                                         {@link 
WorkerMemoryParameters#getBroadcastJoinMemory()}
    */
   public BroadcastJoinHelper(
       final Int2IntMap inputNumberToProcessorChannelMap,
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java
index 2c454e1d45..fdc80560f2 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessor.java
@@ -41,7 +41,6 @@ import org.apache.druid.frame.segment.FrameCursor;
 import org.apache.druid.frame.write.FrameWriter;
 import org.apache.druid.frame.write.FrameWriterFactory;
 import org.apache.druid.java.util.common.ISE;
-import org.apache.druid.msq.exec.Limits;
 import org.apache.druid.msq.indexing.error.MSQException;
 import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;
 import org.apache.druid.msq.input.ReadableInput;
@@ -122,6 +121,7 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
   private final String rightPrefix;
   private final JoinType joinType;
   private final JoinColumnSelectorFactory joinColumnSelectorFactory = new 
JoinColumnSelectorFactory();
+  private final long maxBufferedBytes;
   private FrameWriter frameWriter = null;
 
   // Used by runIncrementally to defer certain logic to the next run.
@@ -137,7 +137,8 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
       FrameWriterFactory frameWriterFactory,
       String rightPrefix,
       List<List<KeyColumn>> keyColumns,
-      JoinType joinType
+      JoinType joinType,
+      long maxBufferedBytes
   )
   {
     this.inputChannels = ImmutableList.of(left.getChannel(), 
right.getChannel());
@@ -146,9 +147,10 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     this.rightPrefix = rightPrefix;
     this.joinType = joinType;
     this.trackers = ImmutableList.of(
-        new Tracker(left, keyColumns.get(LEFT)),
-        new Tracker(right, keyColumns.get(RIGHT))
+        new Tracker(left, keyColumns.get(LEFT), maxBufferedBytes),
+        new Tracker(right, keyColumns.get(RIGHT), maxBufferedBytes)
     );
+    this.maxBufferedBytes = maxBufferedBytes;
   }
 
   @Override
@@ -166,10 +168,10 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
   @Override
   public ReturnOrAwait<Long> runIncrementally(IntSet readableInputs) throws 
IOException
   {
-    // Fetch enough frames such that each tracker has one readable row.
+    // Fetch enough frames such that each tracker has one readable row (or is 
done).
     for (int i = 0; i < inputChannels.size(); i++) {
       final Tracker tracker = trackers.get(i);
-      if (tracker.isAtEndOfPushedData() && !pushNextFrame(i)) {
+      if (tracker.needsMoreDataForCurrentCursor() && !pushNextFrame(i)) {
         return nextAwait();
       }
     }
@@ -178,8 +180,8 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     startNewFrameIfNeeded();
 
     while (!allTrackersAreAtEnd()
-           && !trackers.get(LEFT).needsMoreData()
-           && !trackers.get(RIGHT).needsMoreData()) {
+           && !trackers.get(LEFT).needsMoreDataForCurrentCursor()
+           && !trackers.get(RIGHT).needsMoreDataForCurrentCursor()) {
       // Algorithm can proceed: not all trackers are at the end of their 
streams, and no tracker needs more data to
       // read the current cursor or move it forward.
       if (nextIterationRunnable != null) {
@@ -192,21 +194,12 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
 
       // Two rows match if the keys compare equal _and_ neither key has a null 
component. (x JOIN y ON x.a = y.a does
       // not match rows where "x.a" is null.)
-      final boolean match = markCmp == 0 && 
trackers.get(LEFT).hasCompletelyNonNullMark();
-
-      // If marked keys are equal on both sides ("match"), at least one side 
must have a complete set of rows
-      // for the marked key.
-      if (match && trackerWithCompleteSetForCurrentKey < 0) {
-        for (int i = 0; i < inputChannels.size(); i++) {
-          final Tracker tracker = trackers.get(i);
-
-          // Fetch up to one frame from each tracker, to check if that tracker 
has a complete set.
-          // Can't fetch more than one frame, because channels are only 
guaranteed to have one frame per run.
-          if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && 
tracker.hasCompleteSetForMark())) {
-            trackerWithCompleteSetForCurrentKey = i;
-            break;
-          }
-        }
+      final boolean marksMatch = markCmp == 0 && 
trackers.get(LEFT).hasCompletelyNonNullMark();
+
+      // If marked keys are equal on both sides ("marksMatch"), at least one 
side needs to have a complete set of rows
+      // for the marked key. Check if this is true, otherwise call nextAwait 
to read more data.
+      if (marksMatch && trackerWithCompleteSetForCurrentKey < 0) {
+        updateTrackerWithCompleteSetForCurrentKey();
 
         if (trackerWithCompleteSetForCurrentKey < 0) {
           // Algorithm cannot proceed; fetch more frames on the next run.
@@ -214,93 +207,177 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
         }
       }
 
-      if (match || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && 
joinType.isRighty())) {
-        // Emit row, if there's room in the current frameWriter.
-        joinColumnSelectorFactory.cmp = markCmp;
-        joinColumnSelectorFactory.match = match;
+      // Emit row if there was a match.
+      if (!emitRowIfNeeded(markCmp, marksMatch)) {
+        return ReturnOrAwait.runAgain();
+      }
+
+      // Advance one or both trackers.
+      advanceTrackersAfterEmittingRow(markCmp, marksMatch);
+    }
+
+    if (allTrackersAreAtEnd()) {
+      flushCurrentFrame();
+      return ReturnOrAwait.returnObject(0L);
+    } else {
+      // Keep reading.
+      return nextAwait();
+    }
+  }
+
+  @Override
+  public void cleanup() throws IOException
+  {
+    FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter, 
() -> trackers.forEach(Tracker::clear));
+  }
+
+  /**
+   * Set {@link #trackerWithCompleteSetForCurrentKey} to the lowest-numbered 
{@link Tracker} that has a complete
+   * set of rows available for its mark.
+   */
+  private void updateTrackerWithCompleteSetForCurrentKey()
+  {
+    for (int i = 0; i < inputChannels.size(); i++) {
+      final Tracker tracker = trackers.get(i);
+
+      // Fetch up to one frame from each tracker, to check if that tracker has 
a complete set.
+      // Can't fetch more than one frame, because channels are only guaranteed 
to have one frame per run.
+      if (tracker.hasCompleteSetForMark() || (pushNextFrame(i) && 
tracker.hasCompleteSetForMark())) {
+        trackerWithCompleteSetForCurrentKey = i;
+        return;
+      }
+    }
+
+    trackerWithCompleteSetForCurrentKey = -1;
+  }
+
+  /**
+   * Emits a joined row based on the current state of all trackers.
+   *
+   * @param markCmp    result of {@link #compareMarks()}
+   * @param marksMatch whether the marks actually matched, taking nulls into 
account
+   *
+   * @return true if cursors should be advanced, false if we should run again 
without moving cursors
+   */
+  private boolean emitRowIfNeeded(final int markCmp, final boolean marksMatch) 
throws IOException
+  {
+    if (marksMatch || (markCmp <= 0 && joinType.isLefty()) || (markCmp >= 0 && 
joinType.isRighty())) {
+      // Emit row, if there's room in the current frameWriter.
+      joinColumnSelectorFactory.cmp = markCmp;
+      joinColumnSelectorFactory.match = marksMatch;
+
+      if (!frameWriter.addSelection()) {
+        if (frameWriter.getNumRows() > 0) {
+          // Out of space in the current frame. Run again without moving 
cursors.
+          flushCurrentFrame();
+          return false;
+        } else {
+          throw new 
FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+        }
+      }
+    }
+
+    return true;
+  }
 
-        if (!frameWriter.addSelection()) {
-          if (frameWriter.getNumRows() > 0) {
-            // Out of space in the current frame. Run again without moving 
cursors.
-            flushCurrentFrame();
-            return ReturnOrAwait.runAgain();
+  /**
+   * Advance one or both trackers after emitting a row.
+   *
+   * @param markCmp    result of {@link #compareMarks()}
+   * @param marksMatch whether the marks actually matched, taking nulls into 
account
+   */
+  private void advanceTrackersAfterEmittingRow(final int markCmp, final 
boolean marksMatch)
+  {
+    if (marksMatch) {
+      // Matching keys. First advance the tracker with the complete set.
+      final Tracker completeSetTracker = 
trackers.get(trackerWithCompleteSetForCurrentKey);
+      final Tracker otherTracker = 
trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT);
+
+      completeSetTracker.advance();
+      if (!completeSetTracker.isCurrentSameKeyAsMark()) {
+        // Reached end of complete set. Advance the other tracker.
+        otherTracker.advance();
+
+        // On next iteration (when we're sure to have data) either rewind the 
complete-set tracker, or update marks
+        // of both, as appropriate.
+        onNextIteration(() -> {
+          if (otherTracker.isCurrentSameKeyAsMark()) {
+            completeSetTracker.rewindToMark();
           } else {
-            throw new 
FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
+            // Reached end of the other side too. Advance marks on both 
trackers.
+            completeSetTracker.markCurrent();
+            trackerWithCompleteSetForCurrentKey = -1;
           }
-        }
+
+          // Always update mark of the other tracker, to enable cleanup of old 
frames. It doesn't ever need to
+          // be rewound.
+          otherTracker.markCurrent();
+        });
+      }
+    } else {
+      // Keys don't match. Advance based on what kind of join this is.
+      final int trackerToAdvance;
+      final boolean skipMarkedKey;
+
+      if (markCmp < 0) {
+        trackerToAdvance = LEFT;
+      } else if (markCmp > 0) {
+        trackerToAdvance = RIGHT;
+      } else {
+        // Key is null on both sides. Note that there is a preference for 
running through the left side first
+        // on a FULL join. It doesn't really matter which side we run through 
first, but we do need to be consistent
+        // for the benefit of the logic in "shouldEmitColumnValue".
+        trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT;
       }
 
-      // Advance one or both trackers.
-      if (match) {
-        // Matching keys. First advance the tracker with the complete set.
-        final Tracker tracker = 
trackers.get(trackerWithCompleteSetForCurrentKey);
-        final Tracker otherTracker = 
trackers.get(trackerWithCompleteSetForCurrentKey == LEFT ? RIGHT : LEFT);
+      // Skip marked key entirely if we're on the "off" side of the join. 
(i.e., right side of a LEFT join.)
+      // Note that for FULL joins, entire keys are never skipped, because they 
are both lefty and righty.
+      if (trackerToAdvance == LEFT) {
+        skipMarkedKey = !joinType.isLefty();
+      } else {
+        skipMarkedKey = !joinType.isRighty();
+      }
+
+      final Tracker tracker = trackers.get(trackerToAdvance);
 
+      // Advance past marked key, or as far as we can.
+      boolean didKeyChange = false;
+
+      do {
+        // Always advance a single row. If we're in "skipMarkedKey" mode, then 
we'll loop through later and
+        // potentially skip multiple rows with the same marked key.
         tracker.advance();
-        if (!tracker.isCurrentSameKeyAsMark()) {
-          // Reached end of complete set. Advance the other tracker.
-          otherTracker.advance();
-
-          // On next iteration (when we're sure to have data) either rewind 
the complete-set tracker, or update marks
-          // of both, as appropriate.
-          onNextIteration(() -> {
-            if (otherTracker.isCurrentSameKeyAsMark()) {
-              otherTracker.markCurrent(); // Set mark to enable cleanup of old 
frames.
-              tracker.rewindToMark();
-            } else {
-              // Reached end of the other side too. Advance marks on both 
trackers.
-              tracker.markCurrent();
-              otherTracker.markCurrent();
-              trackerWithCompleteSetForCurrentKey = -1;
-            }
-          });
-        }
-      } else {
-        final int trackerToAdvance;
 
-        if (markCmp < 0) {
-          trackerToAdvance = LEFT;
-        } else if (markCmp > 0) {
-          trackerToAdvance = RIGHT;
-        } else {
-          // Key is null on both sides. Note that there is a preference for 
running through the left side first
-          // on a FULL join. It doesn't really matter which side we run 
through first, but we do need to be consistent
-          // for the benefit of the logic in "shouldEmitColumnValue".
-          trackerToAdvance = joinType.isLefty() ? LEFT : RIGHT;
+        if (tracker.isAtEndOfPushedData()) {
+          break;
         }
 
-        final Tracker tracker = trackers.get(trackerToAdvance);
+        didKeyChange = !tracker.isCurrentSameKeyAsMark();
 
-        tracker.advance();
+        // Always update mark, even if key hasn't changed, to enable cleanup 
of old frames.
+        tracker.markCurrent();
+      } while (skipMarkedKey && !didKeyChange);
 
-        // On next iteration (when we're sure to have data), update mark if 
the key changed.
+      if (didKeyChange) {
+        trackerWithCompleteSetForCurrentKey = -1;
+      } else if (tracker.isAtEndOfPushedData()) {
+        // Not clear if we reached a new key or not.
+        // So, on next iteration (when we're sure to have data), check if 
we've moved on to a new key.
         onNextIteration(() -> {
           if (!tracker.isCurrentSameKeyAsMark()) {
-            tracker.markCurrent();
             trackerWithCompleteSetForCurrentKey = -1;
           }
+
+          // Always update mark, even if key hasn't changed, to enable cleanup 
of old frames.
+          tracker.markCurrent();
         });
       }
     }
-
-    if (allTrackersAreAtEnd()) {
-      flushCurrentFrame();
-      return ReturnOrAwait.returnObject(0L);
-    } else {
-      // Keep reading.
-      return nextAwait();
-    }
-  }
-
-  @Override
-  public void cleanup() throws IOException
-  {
-    FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter, 
() -> trackers.forEach(Tracker::clear));
   }
 
   /**
-   * Returns a {@link ReturnOrAwait#awaitAll} for the channel numbers that 
need more data and have not yet hit their
-   * buffered-bytes limit, {@link 
Limits#MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN}.
+   * Returns a {@link ReturnOrAwait#awaitAll} for channels where {@link 
Tracker#needsMoreDataForCurrentCursor()}
+   * and {@link Tracker#canBufferMoreFrames()}.
    *
    * If all channels have hit their limit, throws {@link MSQException} with 
{@link TooManyRowsWithSameKeyFault}.
    */
@@ -309,10 +386,11 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     final IntSet awaitSet = new IntOpenHashSet();
     int trackerAtLimit = -1;
 
+    // Add all trackers that "needsMoreData" to awaitSet.
     for (int i = 0; i < inputChannels.size(); i++) {
       final Tracker tracker = trackers.get(i);
-      if (tracker.needsMoreData()) {
-        if (tracker.totalBytesBuffered() < 
Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) {
+      if (tracker.needsMoreDataForCurrentCursor()) {
+        if (tracker.canBufferMoreFrames()) {
           awaitSet.add(i);
         } else if (trackerAtLimit < 0) {
           trackerAtLimit = i;
@@ -320,19 +398,31 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
       }
     }
 
-    if (awaitSet.isEmpty() && trackerAtLimit > 0) {
+    if (awaitSet.isEmpty()) {
+      // No tracker reported that it "needsMoreData" to read the current 
cursor. However, we may still need to read
+      // more data to have a complete set for the current mark.
+      for (int i = 0; i < inputChannels.size(); i++) {
+        final Tracker tracker = trackers.get(i);
+        if (!tracker.hasCompleteSetForMark()) {
+          if (tracker.canBufferMoreFrames()) {
+            awaitSet.add(i);
+          } else if (trackerAtLimit < 0) {
+            trackerAtLimit = i;
+          }
+        }
+      }
+    }
+
+    if (awaitSet.isEmpty() && trackerAtLimit >= 0) {
       // All trackers that need more data are at their max buffered bytes 
limit. Generate a nice exception.
       final Tracker tracker = trackers.get(trackerAtLimit);
-      if (tracker.totalBytesBuffered() > 
Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN) {
-        // Generate a nice exception.
-        throw new MSQException(
-            new TooManyRowsWithSameKeyFault(
-                tracker.readMarkKey(),
-                tracker.totalBytesBuffered(),
-                Limits.MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN
-            )
-        );
-      }
+      throw new MSQException(
+          new TooManyRowsWithSameKeyFault(
+              tracker.readMarkKey(),
+              tracker.totalBytesBuffered(),
+              maxBufferedBytes
+          )
+      );
     }
 
     return ReturnOrAwait.awaitAll(awaitSet);
@@ -353,7 +443,13 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
   }
 
   /**
-   * Compares the marked rows of the two {@link #trackers}.
+   * Compares the marked rows of the two {@link #trackers}. This method 
returns 0 if both sides are null, even
+   * though this is not considered a match by join semantics. Therefore, it is 
important to also check
+   * {@link Tracker#hasCompletelyNonNullMark()}.
+   *
+   * @return negative if {@link #LEFT} key is earlier, positive if {@link 
#RIGHT} key is earlier, zero if the keys
+   * are the same. Returns zero even if a key component is null, even though 
this is not considered a match by
+   * join semantics.
    *
    * @throws IllegalStateException if either tracker does not have a marked 
row and is not completely done
    */
@@ -394,6 +490,8 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     } else if (channel.isFinished()) {
       tracker.push(null);
       return true;
+    } else if (!tracker.canBufferMoreFrames()) {
+      return false;
     } else {
       final Frame frame = channel.read();
 
@@ -450,6 +548,7 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     private final List<FrameHolder> holders = new ArrayList<>();
     private final ReadableInput input;
     private final List<KeyColumn> keyColumns;
+    private final long maxBytesBuffered;
 
     // markFrame and markRow are the first frame and row with the current key.
     private int markFrame = -1;
@@ -461,10 +560,11 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     // done indicates that no more data is available in the channel.
     private boolean done;
 
-    public Tracker(ReadableInput input, List<KeyColumn> keyColumns)
+    public Tracker(ReadableInput input, List<KeyColumn> keyColumns, long 
maxBytesBuffered)
     {
       this.input = input;
       this.keyColumns = keyColumns;
+      this.maxBytesBuffered = maxBytesBuffered;
     }
 
     /**
@@ -533,6 +633,16 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
       return bytes;
     }
 
+    /**
+     * Whether this tracker can accept more frames without exceeding {@link 
#maxBufferedBytes}. Always returns true
+     * if the number of buffered frames is zero or one, because the join 
algorithm may require two frames being
+     * buffered. (For example, if we need to verify that the last row in a 
frame contains a complete set of a key.)
+     */
+    public boolean canBufferMoreFrames()
+    {
+      return holders.size() <= 1 || totalBytesBuffered() < maxBytesBuffered;
+    }
+
     /**
      * Cursor containing the current row.
      */
@@ -655,7 +765,7 @@ public class SortMergeJoinFrameProcessor implements 
FrameProcessor<Long>
     /**
      * Whether this tracker needs more data in order to read the current 
cursor location or move it forward.
      */
-    public boolean needsMoreData()
+    public boolean needsMoreDataForCurrentCursor()
     {
       return !done && isAtEndOfPushedData();
     }
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java
index 9aa5063092..76e05d3ce0 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorFactory.java
@@ -180,7 +180,8 @@ public class SortMergeJoinFrameProcessorFactory extends 
BaseFrameProcessorFactor
               
stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()),
               rightPrefix,
               keyColumns,
-              joinType
+              joinType,
+              frameContext.memoryParameters().getSortMergeJoinMemory()
           );
         }
     );
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
index 78ecacbef2..29614fc073 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerMemoryParametersTest.java
@@ -32,11 +32,11 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_oneWorkerInJvm_alone()
   {
-    Assert.assertEquals(params(1, 41, 224_785_000, 100_650_000, 75_000_000), 
create(1_000_000_000, 1, 1, 1, 0, 0));
-    Assert.assertEquals(params(2, 13, 149_410_000, 66_900_000, 75_000_000), 
create(1_000_000_000, 1, 2, 1, 0, 0));
-    Assert.assertEquals(params(4, 3, 89_110_000, 39_900_000, 75_000_000), 
create(1_000_000_000, 1, 4, 1, 0, 0));
-    Assert.assertEquals(params(3, 2, 48_910_000, 21_900_000, 75_000_000), 
create(1_000_000_000, 1, 8, 1, 0, 0));
-    Assert.assertEquals(params(2, 2, 33_448_460, 14_976_922, 75_000_000), 
create(1_000_000_000, 1, 12, 1, 0, 0));
+    Assert.assertEquals(params(335_500_000, 1, 41, 75_000_000), 
create(1_000_000_000, 1, 1, 1, 0, 0));
+    Assert.assertEquals(params(223_000_000, 2, 13, 75_000_000), 
create(1_000_000_000, 1, 2, 1, 0, 0));
+    Assert.assertEquals(params(133_000_000, 4, 3, 75_000_000), 
create(1_000_000_000, 1, 4, 1, 0, 0));
+    Assert.assertEquals(params(73_000_000, 3, 2, 75_000_000), 
create(1_000_000_000, 1, 8, 1, 0, 0));
+    Assert.assertEquals(params(49_923_076, 2, 2, 75_000_000), 
create(1_000_000_000, 1, 12, 1, 0, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
@@ -54,8 +54,8 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_oneWorkerInJvm_twoHundredWorkersInCluster()
   {
-    Assert.assertEquals(params(1, 83, 317_580_000, 142_200_000, 150_000_000), 
create(2_000_000_000, 1, 1, 200, 0, 0));
-    Assert.assertEquals(params(2, 27, 166_830_000, 74_700_000, 150_000_000), 
create(2_000_000_000, 1, 2, 200, 0, 0));
+    Assert.assertEquals(params(474_000_000, 1, 83, 150_000_000), 
create(2_000_000_000, 1, 1, 200, 0, 0));
+    Assert.assertEquals(params(249_000_000, 2, 27, 150_000_000), 
create(2_000_000_000, 1, 2, 200, 0, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
@@ -68,11 +68,11 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_fourWorkersInJvm_twoHundredWorkersInCluster()
   {
-    Assert.assertEquals(params(1, 150, 679_380_000, 304_200_000, 168_750_000), 
create(9_000_000_000L, 4, 1, 200, 0, 0));
-    Assert.assertEquals(params(2, 62, 543_705_000, 243_450_000, 168_750_000), 
create(9_000_000_000L, 4, 2, 200, 0, 0));
-    Assert.assertEquals(params(4, 22, 374_111_250, 167_512_500, 168_750_000), 
create(9_000_000_000L, 4, 4, 200, 0, 0));
-    Assert.assertEquals(params(4, 14, 204_517_500, 91_575_000, 168_750_000), 
create(9_000_000_000L, 4, 8, 200, 0, 0));
-    Assert.assertEquals(params(4, 8, 68_842_500, 30_825_000, 168_750_000), 
create(9_000_000_000L, 4, 16, 200, 0, 0));
+    Assert.assertEquals(params(1_014_000_000, 1, 150, 168_750_000), 
create(9_000_000_000L, 4, 1, 200, 0, 0));
+    Assert.assertEquals(params(811_500_000, 2, 62, 168_750_000), 
create(9_000_000_000L, 4, 2, 200, 0, 0));
+    Assert.assertEquals(params(558_375_000, 4, 22, 168_750_000), 
create(9_000_000_000L, 4, 4, 200, 0, 0));
+    Assert.assertEquals(params(305_250_000, 4, 14, 168_750_000), 
create(9_000_000_000L, 4, 8, 200, 0, 0));
+    Assert.assertEquals(params(102_750_000, 4, 8, 168_750_000), 
create(9_000_000_000L, 4, 16, 200, 0, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
@@ -82,7 +82,7 @@ public class WorkerMemoryParametersTest
     Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault());
 
     // Make sure 124 actually works, and 125 doesn't. (Verify the error 
message above.)
-    Assert.assertEquals(params(4, 3, 16_750_000, 7_500_000, 150_000_000), 
create(8_000_000_000L, 4, 32, 124, 0, 0));
+    Assert.assertEquals(params(25_000_000, 4, 3, 150_000_000), 
create(8_000_000_000L, 4, 32, 124, 0, 0));
 
     final MSQException e2 = Assert.assertThrows(
         MSQException.class,
@@ -96,8 +96,8 @@ public class WorkerMemoryParametersTest
   public void test_oneWorkerInJvm_smallWorkerCapacity()
   {
     // Supersorter max channels per processor are one less than they are 
usually to account for extra frames that are required while creating composing 
output channels
-    Assert.assertEquals(params(1, 3, 27_604_000, 12_360_000, 9_600_000), 
create(128_000_000, 1, 1, 1, 0, 0));
-    Assert.assertEquals(params(1, 1, 17_956_000, 8_040_000, 9_600_000), 
create(128_000_000, 1, 2, 1, 0, 0));
+    Assert.assertEquals(params(41_200_000, 1, 3, 9_600_000), 
create(128_000_000, 1, 1, 1, 0, 0));
+    Assert.assertEquals(params(26_800_000, 1, 1, 9_600_000), 
create(128_000_000, 1, 2, 1, 0, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
@@ -120,14 +120,10 @@ public class WorkerMemoryParametersTest
   @Test
   public void test_fourWorkersInJvm_twoHundredWorkersInCluster_hashPartitions()
   {
-    Assert.assertEquals(
-        params(1, 150, 545_380_000, 244_200_000, 168_750_000), 
create(9_000_000_000L, 4, 1, 200, 200, 0));
-    Assert.assertEquals(
-        params(2, 62, 409_705_000, 183_450_000, 168_750_000), 
create(9_000_000_000L, 4, 2, 200, 200, 0));
-    Assert.assertEquals(
-        params(4, 22, 240_111_250, 107_512_500, 168_750_000), 
create(9_000_000_000L, 4, 4, 200, 200, 0));
-    Assert.assertEquals(
-        params(4, 14, 70_517_500, 31_575_000, 168_750_000), 
create(9_000_000_000L, 4, 8, 200, 200, 0));
+    Assert.assertEquals(params(814_000_000, 1, 150, 168_750_000), 
create(9_000_000_000L, 4, 1, 200, 200, 0));
+    Assert.assertEquals(params(611_500_000, 2, 62, 168_750_000), 
create(9_000_000_000L, 4, 2, 200, 200, 0));
+    Assert.assertEquals(params(358_375_000, 4, 22, 168_750_000), 
create(9_000_000_000L, 4, 4, 200, 200, 0));
+    Assert.assertEquals(params(105_250_000, 4, 14, 168_750_000), 
create(9_000_000_000L, 4, 8, 200, 200, 0));
 
     final MSQException e = Assert.assertThrows(
         MSQException.class,
@@ -137,7 +133,7 @@ public class WorkerMemoryParametersTest
     Assert.assertEquals(new TooManyWorkersFault(200, 138), e.getFault());
 
     // Make sure 138 actually works, and 139 doesn't. (Verify the error 
message above.)
-    Assert.assertEquals(params(4, 8, 17_922_500, 8_025_000, 168_750_000), 
create(9_000_000_000L, 4, 16, 138, 138, 0));
+    Assert.assertEquals(params(26_750_000, 4, 8, 168_750_000), 
create(9_000_000_000L, 4, 16, 138, 138, 0));
 
     final MSQException e2 = Assert.assertThrows(
         MSQException.class,
@@ -165,18 +161,16 @@ public class WorkerMemoryParametersTest
   }
 
   private static WorkerMemoryParameters params(
+      final long processorBundleMemory,
       final int superSorterMaxActiveProcessors,
       final int superSorterMaxChannelsPerProcessor,
-      final long appenderatorMemory,
-      final long broadcastJoinMemory,
       final int partitionStatisticsMaxRetainedBytes
   )
   {
     return new WorkerMemoryParameters(
+        processorBundleMemory,
         superSorterMaxActiveProcessors,
         superSorterMaxChannelsPerProcessor,
-        appenderatorMemory,
-        broadcastJoinMemory,
         partitionStatisticsMaxRetainedBytes
     );
   }
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
index cfc74d792f..060b14cec1 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/querykit/common/SortMergeJoinFrameProcessorTest.java
@@ -20,8 +20,10 @@
 package org.apache.druid.msq.querykit.common;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
+import com.google.common.primitives.Ints;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.MoreExecutors;
 import org.apache.druid.common.config.NullHandling;
@@ -46,6 +48,8 @@ import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.concurrent.Execs;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.msq.indexing.error.MSQException;
+import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;
 import org.apache.druid.msq.input.ReadableInput;
 import org.apache.druid.msq.kernel.StageId;
 import org.apache.druid.msq.kernel.StagePartition;
@@ -58,8 +62,12 @@ import org.apache.druid.segment.join.JoinTestHelper;
 import org.apache.druid.segment.join.JoinType;
 import org.apache.druid.testing.InitializedNullHandlingTest;
 import org.apache.druid.timeline.SegmentId;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Assert;
+import org.junit.Assume;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -79,6 +87,7 @@ import java.util.concurrent.TimeUnit;
 public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
 {
   private static final StagePartition STAGE_PARTITION = new StagePartition(new 
StageId("q", 0), 0);
+  private static final long MAX_BUFFERED_BYTES = 10_000_000;
 
   private final int rowsPerInputFrame;
   private final int rowsPerOutputFrame;
@@ -154,7 +163,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.LEFT
+        JoinType.LEFT,
+        MAX_BUFFERED_BYTES
     );
 
     assertResult(processor, outputChannel.readable(), joinSignature, 
Collections.emptyList());
@@ -198,7 +208,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.LEFT
+        JoinType.LEFT,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -273,7 +284,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.INNER
+        JoinType.INNER,
+        MAX_BUFFERED_BYTES
     );
 
     assertResult(processor, outputChannel.readable(), joinSignature, 
Collections.emptyList());
@@ -313,7 +325,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.LEFT
+        JoinType.LEFT,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -383,7 +396,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
         makeFrameWriterFactory(joinSignature),
         "j0.",
         ImmutableList.of(Collections.emptyList(), Collections.emptyList()),
-        JoinType.INNER
+        JoinType.INNER,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -495,7 +509,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
                 new KeyColumn("regionIsoCode", KeyOrder.ASCENDING)
             )
         ),
-        JoinType.LEFT
+        JoinType.LEFT,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -573,7 +588,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("regionIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("regionIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.RIGHT
+        JoinType.RIGHT,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -654,7 +670,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("regionIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("regionIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.FULL
+        JoinType.FULL,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -732,7 +749,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryNumber", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryNumber", 
KeyOrder.ASCENDING))
         ),
-        JoinType.LEFT
+        JoinType.LEFT,
+        MAX_BUFFERED_BYTES
     );
 
     final String countryCodeForNull;
@@ -825,7 +843,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryNumber", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryNumber", 
KeyOrder.ASCENDING))
         ),
-        JoinType.RIGHT
+        JoinType.RIGHT,
+        MAX_BUFFERED_BYTES
     );
 
     final String countryCodeForNull;
@@ -918,7 +937,8 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
             ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
         ),
-        JoinType.INNER
+        JoinType.INNER,
+        MAX_BUFFERED_BYTES
     );
 
     final List<List<Object>> expectedRows = Arrays.asList(
@@ -950,6 +970,234 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
     assertResult(processor, outputChannel.readable(), joinSignature, 
expectedRows);
   }
 
+  @Test
+  public void testInnerJoinCountryIsoCode_withMaxBufferedBytesLimit_succeeds() 
throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = 
BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("page", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("j0.countryName", ColumnType.STRING)
+                    .add("j0.countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new 
SortMergeJoinFrameProcessor(
+        factChannel,
+        countriesChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        1
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 
2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 
3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", 
"Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 
13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, 
expectedRows);
+  }
+
+  @Test
+  public void 
testInnerJoinCountryIsoCode_backwards_withMaxBufferedBytesLimit_succeeds() 
throws Exception
+  {
+    final ReadableInput factChannel = buildFactInput(
+        ImmutableList.of(
+            new KeyColumn("countryIsoCode", KeyOrder.ASCENDING),
+            new KeyColumn("page", KeyOrder.ASCENDING)
+        )
+    );
+
+    final ReadableInput countriesChannel =
+        buildCountriesInput(ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = 
BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("j0.page", ColumnType.STRING)
+                    .add("j0.countryIsoCode", ColumnType.STRING)
+                    .add("countryIsoCode", ColumnType.STRING)
+                    .add("countryName", ColumnType.STRING)
+                    .add("countryNumber", ColumnType.LONG)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new 
SortMergeJoinFrameProcessor(
+        countriesChannel,
+        factChannel,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("countryIsoCode", 
KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        1
+    );
+
+    final List<List<Object>> expectedRows = Arrays.asList(
+        Arrays.asList("Peremptory norm", "AU", "AU", "Australia", 0L),
+        Arrays.asList("Didier Leclair", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Les Argonautes", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Sarah Michelle Gellar", "CA", "CA", "Canada", 1L),
+        Arrays.asList("Golpe de Estado en Chile de 1973", "CL", "CL", "Chile", 
2L),
+        Arrays.asList("Diskussion:Sebastian Schulz", "DE", "DE", "Germany", 
3L),
+        Arrays.asList("Gabinete Ministerial de Rafael Correa", "EC", "EC", 
"Ecuador", 4L),
+        Arrays.asList("Saison 9 de Secret Story", "FR", "FR", "France", 5L),
+        Arrays.asList("Glasgow", "GB", "GB", "United Kingdom", 6L),
+        Arrays.asList("Giusy Ferreri discography", "IT", "IT", "Italy", 7L),
+        Arrays.asList("Roma-Bangkok", "IT", "IT", "Italy", 7L),
+        Arrays.asList("青野武", "JP", "JP", "Japan", 8L),
+        Arrays.asList("유희왕 GX", "KR", "KR", "Republic of Korea", 9L),
+        Arrays.asList("History of Fourems", "MMMM", "MMMM", "Fourems", 205L),
+        Arrays.asList("Mathis Bolly", "MX", "MX", "Mexico", 10L),
+        Arrays.asList("Алиса в Зазеркалье", "NO", "NO", "Norway", 11L),
+        Arrays.asList("Cream Soda", "SU", "SU", "States United", 15L),
+        Arrays.asList("Wendigo", "SV", "SV", "El Salvador", 12L),
+        Arrays.asList("Carlo Curti", "US", "US", "United States", 13L),
+        Arrays.asList("DirecTV", "US", "US", "United States", 13L),
+        Arrays.asList("Old Anatolian Turkish", "US", "US", "United States", 
13L),
+        Arrays.asList("Otjiwarongo Airport", "US", "US", "United States", 13L),
+        Arrays.asList("President of India", "US", "US", "United States", 13L)
+    );
+
+    assertResult(processor, outputChannel.readable(), joinSignature, 
expectedRows);
+  }
+
+  @Test
+  public void testCountrySelfJoin() throws Exception
+  {
+    final ReadableInput factChannel1 = buildFactInput(ImmutableList.of(new 
KeyColumn("channel", KeyOrder.ASCENDING)));
+    final ReadableInput factChannel2 = buildFactInput(ImmutableList.of(new 
KeyColumn("channel", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = 
BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("channel", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new 
SortMergeJoinFrameProcessor(
+        factChannel1,
+        factChannel2,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        MAX_BUFFERED_BYTES
+    );
+
+    final List<List<Object>> expectedRows = new ArrayList<>();
+
+    final ImmutableMap<String, Long> expectedCounts =
+        ImmutableMap.<String, Long>builder()
+                    .put("#ca.wikipedia", 1L)
+                    .put("#de.wikipedia", 1L)
+                    .put("#en.wikipedia", 196L)
+                    .put("#es.wikipedia", 16L)
+                    .put("#fr.wikipedia", 9L)
+                    .put("#ja.wikipedia", 1L)
+                    .put("#ko.wikipedia", 1L)
+                    .put("#ru.wikipedia", 1L)
+                    .put("#vi.wikipedia", 9L)
+                    .build();
+
+    for (final Map.Entry<String, Long> entry : expectedCounts.entrySet()) {
+      for (int i = 0; i < Ints.checkedCast(entry.getValue()); i++) {
+        expectedRows.add(Collections.singletonList(entry.getKey()));
+      }
+    }
+
+    assertResult(processor, outputChannel.readable(), joinSignature, 
expectedRows);
+  }
+
+  @Test
+  public void testCountrySelfJoin_withMaxBufferedBytesLimit_fails() throws 
Exception
+  {
+    // Test is only valid when rowsPerInputFrame is low enough that we get 
multiple frames.
+    Assume.assumeThat(rowsPerInputFrame, Matchers.lessThanOrEqualTo(7));
+
+    final ReadableInput factChannel1 = buildFactInput(ImmutableList.of(new 
KeyColumn("channel", KeyOrder.ASCENDING)));
+    final ReadableInput factChannel2 = buildFactInput(ImmutableList.of(new 
KeyColumn("channel", KeyOrder.ASCENDING)));
+
+    final BlockingQueueFrameChannel outputChannel = 
BlockingQueueFrameChannel.minimal();
+
+    final RowSignature joinSignature =
+        RowSignature.builder()
+                    .add("channel", ColumnType.STRING)
+                    .build();
+
+    final SortMergeJoinFrameProcessor processor = new 
SortMergeJoinFrameProcessor(
+        factChannel1,
+        factChannel2,
+        outputChannel.writable(),
+        makeFrameWriterFactory(joinSignature),
+        "j0.",
+        ImmutableList.of(
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING)),
+            ImmutableList.of(new KeyColumn("channel", KeyOrder.ASCENDING))
+        ),
+        JoinType.INNER,
+        1
+    );
+
+    final RuntimeException e = Assert.assertThrows(
+        RuntimeException.class,
+        () -> run(processor, outputChannel.readable(), joinSignature)
+    );
+
+    MatcherAssert.assertThat(e.getCause(), 
CoreMatchers.instanceOf(RuntimeException.class));
+    MatcherAssert.assertThat(e.getCause().getCause(), 
CoreMatchers.instanceOf(MSQException.class));
+    MatcherAssert.assertThat(
+        ((MSQException) e.getCause().getCause()).getFault(),
+        CoreMatchers.instanceOf(TooManyRowsWithSameKeyFault.class)
+    );
+  }
+
   private void assertResult(
       final SortMergeJoinFrameProcessor processor,
       final ReadableFrameChannel readableOutputChannel,
@@ -957,14 +1205,25 @@ public class SortMergeJoinFrameProcessorTest extends 
InitializedNullHandlingTest
       final List<List<Object>> expectedRows
   )
   {
-    final ListenableFuture<Long> retVal = exec.runFully(processor, null);
+    final List<List<Object>> rowsFromProcessor = run(processor, 
readableOutputChannel, joinSignature);
+    FrameTestUtil.assertRowsEqual(Sequences.simple(expectedRows), 
Sequences.simple(rowsFromProcessor));
+  }
+
+  private List<List<Object>> run(
+      final SortMergeJoinFrameProcessor processor,
+      final ReadableFrameChannel readableOutputChannel,
+      final RowSignature joinSignature
+  )
+  {
+    final ListenableFuture<Long> retValFromProcessor = 
exec.runFully(processor, null);
     final Sequence<List<Object>> rowsFromProcessor = 
FrameTestUtil.readRowsFromFrameChannel(
         readableOutputChannel,
         FrameReader.create(joinSignature)
     );
 
-    FrameTestUtil.assertRowsEqual(Sequences.simple(expectedRows), 
rowsFromProcessor);
-    Assert.assertEquals(0L, (long) FutureUtils.getUnchecked(retVal, true));
+    final List<List<Object>> rows = rowsFromProcessor.toList();
+    Assert.assertEquals(0L, (long) 
FutureUtils.getUnchecked(retValFromProcessor, true));
+    return rows;
   }
 
   private ReadableInput buildFactInput(final List<KeyColumn> keyColumns) 
throws IOException


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to