This is an automated email from the ASF dual-hosted git repository.

lakshsingla pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new af5399cd9db Fixes a bug when running queries with a limit clause 
(#16643)
af5399cd9db is described below

commit af5399cd9db609e20578e5a0f4308510b1d2cf58
Author: Adarsh Sanjeev <[email protected]>
AuthorDate: Tue Jul 9 14:29:12 2024 +0530

    Fixes a bug when running queries with a limit clause (#16643)
    
    Add a shuffle, built from the resultShuffleSpecFactory, after the limit 
processor, depending on the query destination. LimitFrameProcessors currently do 
not update the partition boosting column, so we also add the boost column to 
the previous stage, if one is required.
---
 .../org/apache/druid/msq/exec/ControllerImpl.java  |  10 +-
 .../druid/msq/indexing/MSQControllerTask.java      |  28 ++-
 .../msq/querykit/groupby/GroupByQueryKit.java      |   6 +-
 .../druid/msq/querykit/scan/ScanQueryKit.java      | 111 +++++------
 .../org/apache/druid/msq/exec/MSQExportTest.java   |  46 +++++
 .../org/apache/druid/msq/exec/MSQFaultsTest.java   |   3 +-
 .../org/apache/druid/msq/exec/MSQInsertTest.java   |  40 +++-
 .../org/apache/druid/msq/exec/MSQReplaceTest.java  |  45 +++++
 .../org/apache/druid/msq/exec/MSQSelectTest.java   | 202 +++++++++++++++++++--
 .../org/apache/druid/msq/test/MSQTestBase.java     |   4 +
 10 files changed, 418 insertions(+), 77 deletions(-)

diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
index ad37c5380c5..3ca6d5780de 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java
@@ -101,9 +101,7 @@ import org.apache.druid.msq.indexing.MSQTuningConfig;
 import org.apache.druid.msq.indexing.WorkerCount;
 import org.apache.druid.msq.indexing.client.ControllerChatHandler;
 import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
-import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
-import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.indexing.error.CanceledFault;
 import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault;
 import org.apache.druid.msq.indexing.error.FaultsExceededChecker;
@@ -1828,9 +1826,9 @@ public class ControllerImpl implements Controller
       );
 
       return builder.build();
-    } else if (querySpec.getDestination() instanceof TaskReportMSQDestination) 
{
+    } else if (MSQControllerTask.writeFinalResultsToTaskReport(querySpec)) {
       return queryDef;
-    } else if (querySpec.getDestination() instanceof 
DurableStorageMSQDestination) {
+    } else if 
(MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)) {
 
       // attaching new query results stage if the final stage does sort during 
shuffle so that results are ordered.
       StageDefinition finalShuffleStageDef = 
queryDef.getFinalStageDefinition();
@@ -2933,12 +2931,12 @@ public class ControllerImpl implements Controller
 
       final InputChannelFactory inputChannelFactory;
 
-      if (queryKernelConfig.isDurableStorage() || 
MSQControllerTask.writeResultsToDurableStorage(querySpec)) {
+      if (queryKernelConfig.isDurableStorage() || 
MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)) {
         inputChannelFactory = 
DurableStorageInputChannelFactory.createStandardImplementation(
             queryId(),
             MSQTasks.makeStorageConnector(context.injector()),
             closer,
-            MSQControllerTask.writeResultsToDurableStorage(querySpec)
+            MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)
         );
       } else {
         inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> 
taskIds);
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
index bdaf3964b29..b9c8ebe3b80 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/MSQControllerTask.java
@@ -52,6 +52,7 @@ import 
org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
 import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
 import org.apache.druid.msq.indexing.destination.MSQDestination;
+import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.util.MultiStageQueryContext;
 import org.apache.druid.query.QueryContext;
 import org.apache.druid.rpc.ServiceClientFactory;
@@ -305,16 +306,38 @@ public class MSQControllerTask extends AbstractTask 
implements ClientTaskQuery,
     return querySpec.getDestination().getDestinationResource();
   }
 
+  /**
+   * Checks whether the task is an ingestion into a Druid datasource.
+   */
   public static boolean isIngestion(final MSQSpec querySpec)
   {
     return querySpec.getDestination() instanceof DataSourceMSQDestination;
   }
 
+  /**
+   * Checks whether the task is an export into external files.
+   */
   public static boolean isExport(final MSQSpec querySpec)
   {
     return querySpec.getDestination() instanceof ExportMSQDestination;
   }
 
+  /**
+   * Checks whether the task is an async query which writes frame files 
containing the final results into durable storage.
+   */
+  public static boolean writeFinalStageResultsToDurableStorage(final MSQSpec 
querySpec)
+  {
+    return querySpec.getDestination() instanceof DurableStorageMSQDestination;
+  }
+
+  /**
+   * Checks whether the task is a query which writes its final results into the task report.
+   */
+  public static boolean writeFinalResultsToTaskReport(final MSQSpec querySpec)
+  {
+    return querySpec.getDestination() instanceof TaskReportMSQDestination;
+  }
+
   /**
    * Returns true if the task reads from the same table as the destination. In 
this case, we would prefer to fail
    * instead of reading any unused segments to ensure that old data is not 
read.
@@ -330,11 +353,6 @@ public class MSQControllerTask extends AbstractTask 
implements ClientTaskQuery,
     }
   }
 
-  public static boolean writeResultsToDurableStorage(final MSQSpec querySpec)
-  {
-    return querySpec.getDestination() instanceof DurableStorageMSQDestination;
-  }
-
   @Override
   public LookupLoadingSpec getLookupLoadingSpec()
   {
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
index f02e505d0c5..eb9953402ba 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/groupby/GroupByQueryKit.java
@@ -185,13 +185,14 @@ public class GroupByQueryKit implements 
QueryKit<GroupByQuery>
       );
 
       if (doLimitOrOffset) {
+        final ShuffleSpec finalShuffleSpec = 
resultShuffleSpecFactory.build(resultClusterBy, false);
         final DefaultLimitSpec limitSpec = (DefaultLimitSpec) 
queryToRun.getLimitSpec();
         queryDefBuilder.add(
             StageDefinition.builder(firstStageNumber + 2)
                            .inputs(new StageInputSpec(firstStageNumber + 1))
                            .signature(resultSignature)
                            .maxWorkerCount(1)
-                           .shuffleSpec(null) // no shuffling should be 
required after a limit processor.
+                           .shuffleSpec(finalShuffleSpec)
                            .processorFactory(
                                new OffsetLimitFrameProcessorFactory(
                                    limitSpec.getOffset(),
@@ -224,12 +225,13 @@ public class GroupByQueryKit implements 
QueryKit<GroupByQuery>
       );
       if (doLimitOrOffset) {
         final DefaultLimitSpec limitSpec = (DefaultLimitSpec) 
queryToRun.getLimitSpec();
+        final ShuffleSpec finalShuffleSpec = 
resultShuffleSpecFactory.build(resultClusterBy, false);
         queryDefBuilder.add(
             StageDefinition.builder(firstStageNumber + 2)
                            .inputs(new StageInputSpec(firstStageNumber + 1))
                            .signature(resultSignature)
                            .maxWorkerCount(1)
-                           .shuffleSpec(null)
+                           .shuffleSpec(finalShuffleSpec)
                            .processorFactory(
                                new OffsetLimitFrameProcessorFactory(
                                    limitSpec.getOffset(),
diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
index 2927264382a..48a17a9e84e 100644
--- 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
+++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/scan/ScanQueryKit.java
@@ -34,6 +34,7 @@ import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.querykit.DataSourcePlan;
 import org.apache.druid.msq.querykit.QueryKit;
 import org.apache.druid.msq.querykit.QueryKitUtils;
+import org.apache.druid.msq.querykit.ShuffleSpecFactories;
 import org.apache.druid.msq.querykit.ShuffleSpecFactory;
 import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
 import org.apache.druid.msq.util.MultiStageQueryContext;
@@ -111,69 +112,77 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
     final ScanQuery queryToRun = 
originalQuery.withDataSource(dataSourcePlan.getNewDataSource());
     final int firstStageNumber = Math.max(minStageNumber, 
queryDefBuilder.getNextStageNumber());
     final RowSignature scanSignature = getAndValidateSignature(queryToRun, 
jsonMapper);
-    final ShuffleSpec shuffleSpec;
-    final RowSignature signatureToUse;
     final boolean hasLimitOrOffset = queryToRun.isLimited() || 
queryToRun.getScanRowsOffset() > 0;
 
+    final RowSignature.Builder signatureBuilder = 
RowSignature.builder().addAll(scanSignature);
+    final Granularity segmentGranularity =
+        QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, 
queryToRun.getContext());
+    final List<KeyColumn> clusterByColumns = new ArrayList<>();
+
+    // Add regular orderBys.
+    for (final ScanQuery.OrderBy orderBy : queryToRun.getOrderBys()) {
+      clusterByColumns.add(
+          new KeyColumn(
+              orderBy.getColumnName(),
+              orderBy.getOrder() == ScanQuery.Order.DESCENDING ? 
KeyOrder.DESCENDING : KeyOrder.ASCENDING
+          )
+      );
+    }
 
-    // We ignore the resultShuffleSpecFactory in case:
-    //  1. There is no cluster by
-    //  2. There is an offset which means everything gets funneled into a 
single partition hence we use MaxCountShuffleSpec
-    if (queryToRun.getOrderBys().isEmpty() && hasLimitOrOffset) {
-      shuffleSpec = MixShuffleSpec.instance();
-      signatureToUse = scanSignature;
-    } else {
-      final RowSignature.Builder signatureBuilder = 
RowSignature.builder().addAll(scanSignature);
-      final Granularity segmentGranularity =
-          QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, 
queryToRun.getContext());
-      final List<KeyColumn> clusterByColumns = new ArrayList<>();
-
-      // Add regular orderBys.
-      for (final ScanQuery.OrderBy orderBy : queryToRun.getOrderBys()) {
-        clusterByColumns.add(
-            new KeyColumn(
-                orderBy.getColumnName(),
-                orderBy.getOrder() == ScanQuery.Order.DESCENDING ? 
KeyOrder.DESCENDING : KeyOrder.ASCENDING
-            )
-        );
-      }
-
-      // Update partition by of next window
-      final RowSignature signatureSoFar = signatureBuilder.build();
-      boolean addShuffle = true;
-      if 
(originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL))
 {
-        final ClusterBy windowClusterBy = (ClusterBy) 
originalQuery.getContext()
-                                                                   
.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL);
-        for (KeyColumn c : windowClusterBy.getColumns()) {
-          if (!signatureSoFar.contains(c.columnName())) {
-            addShuffle = false;
-            break;
-          }
-        }
-        if (addShuffle) {
-          clusterByColumns.addAll(windowClusterBy.getColumns());
+    // Update partition by of next window
+    final RowSignature signatureSoFar = signatureBuilder.build();
+    boolean addShuffle = true;
+    if 
(originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL))
 {
+      final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext()
+                                                                 
.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL);
+      for (KeyColumn c : windowClusterBy.getColumns()) {
+        if (!signatureSoFar.contains(c.columnName())) {
+          addShuffle = false;
+          break;
         }
-      } else {
-        // Add partition boosting column.
-        clusterByColumns.add(new 
KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));
-        signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, 
ColumnType.LONG);
       }
+      if (addShuffle) {
+        clusterByColumns.addAll(windowClusterBy.getColumns());
+      }
+    } else {
+      // Add partition boosting column.
+      clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, 
KeyOrder.ASCENDING));
+      signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, 
ColumnType.LONG);
+    }
 
+    final ClusterBy clusterBy =
+        QueryKitUtils.clusterByWithSegmentGranularity(new 
ClusterBy(clusterByColumns, 0), segmentGranularity);
+    final ShuffleSpec finalShuffleSpec = 
resultShuffleSpecFactory.build(clusterBy, false);
 
-      final ClusterBy clusterBy =
-          QueryKitUtils.clusterByWithSegmentGranularity(new 
ClusterBy(clusterByColumns, 0), segmentGranularity);
-      shuffleSpec = resultShuffleSpecFactory.build(clusterBy, false);
-      signatureToUse = QueryKitUtils.sortableSignature(
-          
QueryKitUtils.signatureWithSegmentGranularity(signatureBuilder.build(), 
segmentGranularity),
-          clusterBy.getColumns()
-      );
+    final RowSignature signatureToUse = QueryKitUtils.sortableSignature(
+        
QueryKitUtils.signatureWithSegmentGranularity(signatureBuilder.build(), 
segmentGranularity),
+        clusterBy.getColumns()
+    );
+
+    ShuffleSpec scanShuffleSpec;
+    if (!hasLimitOrOffset) {
+      // If there is no limit or offset, apply the final shuffle spec in this stage. 
This ensures partition sizes etc. are respected.
+      scanShuffleSpec = finalShuffleSpec;
+    } else {
+      // If there is a limit or offset, check if there are any non-boost columns to 
sort on.
+      boolean requiresSort = clusterByColumns.stream()
+                                             .anyMatch(keyColumn -> 
!QueryKitUtils.PARTITION_BOOST_COLUMN.equals(keyColumn.columnName()));
+      if (requiresSort) {
+        // If yes, do a sort into a single partition.
+        scanShuffleSpec = 
ShuffleSpecFactories.singlePartition().build(clusterBy, false);
+      } else {
+        // If the only clusterBy column is the boost column, we just use a mix 
shuffle to avoid an unnecessary sort.
+        // Note that we still need the boost column to be present in the row 
signature, since the limit stage would
+        // need it to be populated to do its own shuffling later.
+        scanShuffleSpec = MixShuffleSpec.instance();
+      }
     }
 
     queryDefBuilder.add(
         StageDefinition.builder(Math.max(minStageNumber, 
queryDefBuilder.getNextStageNumber()))
                        .inputs(dataSourcePlan.getInputSpecs())
                        .broadcastInputs(dataSourcePlan.getBroadcastInputs())
-                       .shuffleSpec(shuffleSpec)
+                       .shuffleSpec(scanShuffleSpec)
                        .signature(signatureToUse)
                        .maxWorkerCount(dataSourcePlan.isSingleWorker() ? 1 : 
maxWorkerCount)
                        .processorFactory(new 
ScanQueryFrameProcessorFactory(queryToRun))
@@ -185,7 +194,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
                          .inputs(new StageInputSpec(firstStageNumber))
                          .signature(signatureToUse)
                          .maxWorkerCount(1)
-                         .shuffleSpec(null) // no shuffling should be required 
after a limit processor.
+                         .shuffleSpec(finalShuffleSpec) // Apply the final 
shuffling after limit spec.
                          .processorFactory(
                              new OffsetLimitFrameProcessorFactory(
                                  queryToRun.getScanRowsOffset(),
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQExportTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQExportTest.java
index 71b816e78c5..538cd471420 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQExportTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQExportTest.java
@@ -316,6 +316,52 @@ public class MSQExportTest extends MSQTestBase
     }
   }
 
+  @Test
+  public void testExportWithLimit() throws IOException
+  {
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("cnt", 
ColumnType.LONG).build();
+
+    File exportDir = newTempFolder("export");
+
+    Map<String, Object> queryContext = new HashMap<>(DEFAULT_MSQ_CONTEXT);
+    queryContext.put(MultiStageQueryContext.CTX_ROWS_PER_PAGE, 1);
+
+    final String sql = StringUtils.format("insert into 
extern(local(exportPath=>'%s')) as csv select cnt, dim1 from foo limit 3", 
exportDir.getAbsolutePath());
+
+    testIngestQuery().setSql(sql)
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(queryContext)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedSegment(ImmutableSet.of())
+                     .setExpectedResultRows(ImmutableList.of())
+                     .verifyResults();
+
+    Assert.assertEquals(
+        ImmutableList.of(
+            "cnt,dim1",
+            "1,"
+        ),
+        readResultsFromFile(new File(exportDir, 
"query-test-query-worker0-partition0.csv"))
+    );
+    Assert.assertEquals(
+        ImmutableList.of(
+            "cnt,dim1",
+            "1,10.1"
+        ),
+        readResultsFromFile(new File(exportDir, 
"query-test-query-worker0-partition1.csv"))
+    );
+    Assert.assertEquals(
+        ImmutableList.of(
+            "cnt,dim1",
+            "1,2"
+            ),
+        readResultsFromFile(new File(exportDir, 
"query-test-query-worker0-partition2.csv"))
+    );
+  }
+
   private void verifyManifestFile(File exportDir, List<File> resultFiles) 
throws IOException
   {
     final File manifestFile = new File(exportDir, 
ExportMetadataManager.MANIFEST_FILE);
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java
index d144e765957..b18e14b04f2 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQFaultsTest.java
@@ -46,6 +46,7 @@ import 
org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
 import org.apache.druid.msq.indexing.error.TooManySegmentsInTimeChunkFault;
 import org.apache.druid.msq.test.MSQTestBase;
 import org.apache.druid.msq.test.MSQTestTaskActionClient;
+import org.apache.druid.msq.util.MultiStageQueryContext;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.realtime.appenderator.SegmentIdWithShardSpec;
@@ -291,7 +292,7 @@ public class MSQFaultsTest extends MSQTestBase
   {
     Map<String, Object> context = ImmutableMap.<String, Object>builder()
                                               .putAll(DEFAULT_MSQ_CONTEXT)
-                                              .put("rowsPerSegment", 1)
+                                              
.put(MultiStageQueryContext.CTX_ROWS_PER_SEGMENT, 1)
                                               .build();
 
 
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
index 03ed429848a..098b143b277 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java
@@ -1455,7 +1455,7 @@ public class MSQInsertTest extends MSQTestBase
                              + "SELECT __time, m1 "
                              + "FROM foo "
                              + "LIMIT 50 "
-                             + "OFFSET 10"
+                             + "OFFSET 10 "
                              + "PARTITIONED BY ALL TIME")
                      .setExpectedValidationErrorMatcher(
                          invalidSqlContains("INSERT and REPLACE queries cannot 
have an OFFSET")
@@ -1464,6 +1464,44 @@ public class MSQInsertTest extends MSQTestBase
                      .verifyPlanningErrors();
   }
 
+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testInsertOnFoo1WithLimit(String contextName, Map<String, 
Object> context)
+  {
+    Map<String, Object> queryContext = ImmutableMap.<String, Object>builder()
+                                                   .putAll(context)
+                                                   
.put(MultiStageQueryContext.CTX_ROWS_PER_SEGMENT, 2)
+                                                   .build();
+
+    List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{946771200000L, "10.1", 1L},
+        new Object[]{978307200000L, "1", 1L},
+        new Object[]{946857600000L, "2", 1L},
+        new Object[]{978480000000L, "abc", 1L}
+    );
+
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("cnt", ColumnType.LONG)
+                                            .build();
+
+    testIngestQuery().setSql(
+                         "insert into foo1 select __time, dim1, cnt from foo 
where dim1 != '' limit 4 partitioned by all clustered by dim1")
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(queryContext)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", 
Intervals.ETERNITY, "test", 0), SegmentId.of("foo1", Intervals.ETERNITY, 
"test", 1)))
+                     .setExpectedResultRows(expectedRows)
+                     .setExpectedMSQSegmentReport(
+                         new MSQSegmentReport(
+                             NumberedShardSpec.class.getSimpleName(),
+                             "Using NumberedShardSpec to generate segments 
since the query is inserting rows."
+                         )
+                     )
+                     .verifyResults();
+  }
+
   @MethodSource("data")
   @ParameterizedTest(name = "{index}:with context {0}")
   public void testCorrectNumberOfWorkersUsedAutoModeWithoutBytesLimit(String 
contextName, Map<String, Object> context) throws IOException
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java
index 7d7f4e310c6..2e05d447910 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQReplaceTest.java
@@ -906,6 +906,51 @@ public class MSQReplaceTest extends MSQTestBase
                      .verifyResults();
   }
 
+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testReplaceOnFoo1WithLimit(String contextName, Map<String, 
Object> context)
+  {
+    Map<String, Object> queryContext = ImmutableMap.<String, Object>builder()
+                                                   .putAll(context)
+                                                   
.put(MultiStageQueryContext.CTX_ROWS_PER_SEGMENT, 2)
+                                                   .build();
+
+    List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{946684800000L, NullHandling.sqlCompatible() ? "" : null},
+        new Object[]{978307200000L, "1"},
+        new Object[]{946771200000L, "10.1"},
+        new Object[]{946857600000L, "2"}
+    );
+
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .build();
+
+    testIngestQuery().setSql(
+                         "REPLACE INTO \"foo1\" OVERWRITE ALL\n"
+                         + "SELECT\n"
+                         + "  \"__time\",\n"
+                         + "  \"dim1\"\n"
+                         + "FROM foo\n"
+                         + "LIMIT 4\n"
+                         + "PARTITIONED BY ALL\n"
+                         + "CLUSTERED BY dim1")
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(queryContext)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedShardSpec(DimensionRangeShardSpec.class)
+                     .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", 
Intervals.ETERNITY, "test", 0), SegmentId.of("foo1", Intervals.ETERNITY, 
"test", 1)))
+                     .setExpectedResultRows(expectedRows)
+                     .setExpectedMSQSegmentReport(
+                         new MSQSegmentReport(
+                             DimensionRangeShardSpec.class.getSimpleName(),
+                             "Using RangeShardSpec to generate segments."
+                         )
+                     )
+                     .verifyResults();
+  }
+
   @MethodSource("data")
   @ParameterizedTest(name = "{index}:with context {0}")
   public void testReplaceTimeChunksLargerThanData(String contextName, 
Map<String, Object> context)
diff --git 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
index 84dddd526c1..2d14e743497 100644
--- 
a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
+++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
@@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.data.input.impl.CsvInputFormat;
+import org.apache.druid.data.input.impl.InlineInputSource;
 import org.apache.druid.data.input.impl.JsonInputFormat;
 import org.apache.druid.data.input.impl.LocalInputSource;
 import org.apache.druid.data.input.impl.systemfield.SystemFields;
@@ -624,8 +625,17 @@ public class MSQSelectTest extends MSQTestBase
                                                .add("dim1", ColumnType.STRING)
                                                .build();
 
+    final ImmutableList<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{1L, ""},
+        new Object[]{1L, "10.1"},
+        new Object[]{1L, "2"},
+        new Object[]{1L, "1"},
+        new Object[]{1L, "def"},
+        new Object[]{1L, "abc"}
+    );
+
     testSelectQuery()
-        .setSql("select cnt,dim1 from foo limit 10")
+        .setSql("select cnt, dim1 from foo limit 10")
         .setExpectedMSQSpec(
             MSQSpec.builder()
                    .query(
@@ -646,6 +656,7 @@ public class MSQSelectTest extends MSQTestBase
         )
         .setQueryContext(context)
         .setExpectedRowSignature(resultSignature)
+        .setExpectedResultRows(expectedResults)
         .setExpectedCountersForStageWorkerChannel(
             CounterSnapshotMatcher
                 .with().totalFiles(1),
@@ -653,22 +664,31 @@ public class MSQSelectTest extends MSQTestBase
         )
         .setExpectedCountersForStageWorkerChannel(
             CounterSnapshotMatcher
-                .with().rows(6).frames(1),
+                .with().rows(6),
             0, 0, "output"
         )
         .setExpectedCountersForStageWorkerChannel(
             CounterSnapshotMatcher
-                .with().rows(6).frames(1),
+                .with().rows(6),
             0, 0, "shuffle"
         )
-        .setExpectedResultRows(ImmutableList.of(
-            new Object[]{1L, ""},
-            new Object[]{1L, "10.1"},
-            new Object[]{1L, "2"},
-            new Object[]{1L, "1"},
-            new Object[]{1L, "def"},
-            new Object[]{1L, "abc"}
-        )).verifyResults();
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(6),
+            1, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(6),
+            1, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                
.with().rows(!context.containsKey(MultiStageQueryContext.CTX_ROWS_PER_PAGE) ? 
new long[] {6} : new long[] {2, 2, 2}),
+            1, 0, "shuffle"
+        )
+        .setExpectedResultRows(expectedResults)
+        .verifyResults();
   }
 
   @MethodSource("data")
@@ -1699,6 +1719,166 @@ public class MSQSelectTest extends MSQTestBase
         )).verifyResults();
   }
 
+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testGroupByWithLimit(String contextName, Map<String, Object> context)
+  {
+    RowSignature expectedResultSignature = RowSignature.builder()
+                                                       .add("dim1", ColumnType.STRING)
+                                                       .add("cnt", ColumnType.LONG)
+                                                       .build();
+
+    GroupByQuery query = GroupByQuery.builder()
+                                     .setDataSource(CalciteTests.DATASOURCE1)
+                                     .setInterval(querySegmentSpec(Filtration.eternity()))
+                                     .setGranularity(Granularities.ALL)
+                                     .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0")))
+                                     .setAggregatorSpecs(
+                                         aggregators(
+                                             new CountAggregatorFactory(
+                                                 "a0"
+                                             )
+                                         )
+                                     )
+                                     .setDimFilter(not(equality("dim1", "", ColumnType.STRING)))
+                                     .setLimit(1)
+                                     .setContext(context)
+                                     .build();
+
+    testSelectQuery()
+        .setSql("SELECT dim1, cnt FROM (SELECT dim1, COUNT(*) AS cnt FROM foo GROUP BY dim1 HAVING dim1 != '' LIMIT 1) LIMIT 20")
+        .setExpectedMSQSpec(MSQSpec.builder()
+                                   .query(query)
+                                   .columnMappings(new ColumnMappings(ImmutableList.of(
+                                       new ColumnMapping("d0", "dim1"),
+                                       new ColumnMapping("a0", "cnt")
+                                   )))
+                                   .tuningConfig(MSQTuningConfig.defaultConfig())
+                                   .destination(isDurableStorageDestination(contextName, context)
+                                                ? DurableStorageMSQDestination.INSTANCE
+                                                : TaskReportMSQDestination.INSTANCE)
+                                   .build())
+        .setExpectedRowSignature(expectedResultSignature)
+        .setQueryContext(context)
+        .setExpectedResultRows(ImmutableList.of(
+            new Object[]{"1", 1L}
+        )).verifyResults();
+  }
+
+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testGroupByWithLimitAndOrdering(String contextName, Map<String, Object> context)
+  {
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("count", ColumnType.LONG)
+                                            .build();
+
+    GroupByQuery query = GroupByQuery.builder()
+                                     .setDataSource(
+                                         new ExternalDataSource(
+                                             new InlineInputSource("dim1\nabc\nxyz\ndef\nxyz\nabc\nxyz\nabc\nxyz\ndef\nbbb\naaa"),
+                                             new CsvInputFormat(null, null, null, true, 0),
+                                             RowSignature.builder().add("dim1", ColumnType.STRING).build()
+                                         )
+                                     )
+                                     .setInterval(querySegmentSpec(Filtration.eternity()))
+                                     .setGranularity(Granularities.ALL)
+                                     .addOrderByColumn(new OrderByColumnSpec("a0", OrderByColumnSpec.Direction.DESCENDING, StringComparators.NUMERIC))
+                                     .addOrderByColumn(new OrderByColumnSpec("d0", OrderByColumnSpec.Direction.ASCENDING, StringComparators.LEXICOGRAPHIC))
+                                     .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0")))
+                                     .setAggregatorSpecs(
+                                         aggregators(
+                                             new CountAggregatorFactory(
+                                                 "a0"
+                                             )
+                                         )
+                                     )
+                                     .setLimit(4)
+                                     .setContext(context)
+                                     .build();
+
+    List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{"xyz", 4L},
+        new Object[]{"abc", 3L},
+        new Object[]{"def", 2L},
+        new Object[]{"aaa", 1L}
+    );
+
+    testSelectQuery()
+        .setSql("WITH \"ext\" AS (\n"
+                + "  SELECT *\n"
+                + "  FROM TABLE(\n"
+                + "    EXTERN(\n"
+                + "      '{\"type\":\"inline\",\"data\":\"dim1\\nabc\\nxyz\\ndef\\nxyz\\nabc\\nxyz\\nabc\\nxyz\\ndef\\nbbb\\naaa\"}',\n"
+                + "      '{\"type\":\"csv\",\"findColumnsFromHeader\":true}'\n"
+                + "    )\n"
+                + "  ) EXTEND (\"dim1\" VARCHAR)\n"
+                + ")\n"
+                + "SELECT\n"
+                + "  \"dim1\",\n"
+                + "  COUNT(*) AS \"count\"\n"
+                + "FROM \"ext\"\n"
+                + "GROUP BY 1\n"
+                + "ORDER BY 2 DESC, 1\n"
+                + "LIMIT 4\n")
+        .setExpectedMSQSpec(MSQSpec.builder()
+                                   .query(query)
+                                   .columnMappings(new ColumnMappings(ImmutableList.of(
+                                       new ColumnMapping("d0", "dim1"),
+                                       new ColumnMapping("a0", "count")
+                                   )))
+                                   .tuningConfig(MSQTuningConfig.defaultConfig())
+                                   .destination(isDurableStorageDestination(contextName, context)
+                                                ? DurableStorageMSQDestination.INSTANCE
+                                                : TaskReportMSQDestination.INSTANCE)
+                                   .build())
+        .setExpectedRowSignature(rowSignature)
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().totalFiles(1),
+            0, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            0, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            1, 0, "shuffle"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            1, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            1, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            2, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(4),
+            2, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(!context.containsKey(MultiStageQueryContext.CTX_ROWS_PER_PAGE) ? new long[] {4} : new long[] {2, 2}),
+            2, 0, "shuffle"
+        )
+        .setQueryContext(context)
+        .setExpectedResultRows(expectedRows)
+        .verifyResults();
+  }
+
   @MethodSource("data")
   @ParameterizedTest(name = "{index}:with context {0}")
   public void testHavingOnApproximateCountDistinct(String contextName, 
Map<String, Object> context)
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index f7c9b3296ca..5f0bd545b7c 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -1288,6 +1288,10 @@ public class MSQTestBase extends BaseCalciteQueryTest
                                                                       .stream()
                                                                       
.filter(segmentId -> segmentId.getInterval()
                                                                                
                     .contains((Long) row[0]))
+                                                                      
.filter(segmentId -> {
+                                                                        
List<List<Object>> lists = segmentIdVsOutputRowsMap.get(segmentId);
+                                                                        return 
lists.contains(Arrays.asList(row));
+                                                                      })
                                                                       
.collect(Collectors.toList());
             if (diskSegmentList.size() != 1) {
               throw new IllegalStateException("Single key in multiple 
partitions");


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to