Akshat-Jain commented on code in PR #16804:
URL: https://github.com/apache/druid/pull/16804#discussion_r1696356682


##########
extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQWindowTest.java:
##########
@@ -2048,4 +2053,143 @@ public void testReplaceGroupByOnWikipedia(String 
contextName, Map<String, Object
                      .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", 
Intervals.ETERNITY, "test", 0)))
                      .verifyResults();
   }
+
+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testWindowOnMixOfEmptyAndNonEmptyOverWithMultipleWorkers(String 
contextName, Map<String, Object> context)
+  {
+    final Map<String, Object> multipleWorkerContext = new HashMap<>(context);
+    multipleWorkerContext.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 5);
+
+    final RowSignature rowSignature = RowSignature.builder()
+                                            .add("countryName", 
ColumnType.STRING)
+                                            .add("cityName", ColumnType.STRING)
+                                            .add("channel", ColumnType.STRING)
+                                            .add("c1", ColumnType.LONG)
+                                            .add("c2", ColumnType.LONG)
+                                            .build();
+
+    final Map<String, Object> contextWithRowSignature =
+        ImmutableMap.<String, Object>builder()
+                    .putAll(multipleWorkerContext)
+                    .put(
+                        DruidQuery.CTX_SCAN_SIGNATURE,
+                        
"[{\"name\":\"d0\",\"type\":\"STRING\"},{\"name\":\"d1\",\"type\":\"STRING\"},{\"name\":\"d2\",\"type\":\"STRING\"},{\"name\":\"w0\",\"type\":\"LONG\"},{\"name\":\"w1\",\"type\":\"LONG\"}]"
+                    )
+                    .build();
+
+    final GroupByQuery groupByQuery = GroupByQuery.builder()
+                                           
.setDataSource(CalciteTests.WIKIPEDIA)
+                                           
.setInterval(querySegmentSpec(Filtration
+                                                                             
.eternity()))
+                                           .setGranularity(Granularities.ALL)
+                                           .setDimensions(dimensions(
+                                               new DefaultDimensionSpec(
+                                                   "countryName",
+                                                   "d0",
+                                                   ColumnType.STRING
+                                               ),
+                                               new DefaultDimensionSpec(
+                                                   "cityName",
+                                                   "d1",
+                                                   ColumnType.STRING
+                                               ),
+                                               new DefaultDimensionSpec(
+                                                   "channel",
+                                                   "d2",
+                                                   ColumnType.STRING
+                                               )
+                                           ))
+                                           .setDimFilter(in("countryName", 
ImmutableList.of("Austria", "Republic of Korea")))
+                                           .setContext(multipleWorkerContext)
+                                           .build();
+
+    final AggregatorFactory[] aggs = {
+        new FilteredAggregatorFactory(new CountAggregatorFactory("w1"), 
notNull("d2"), "w1")
+    };
+
+    final WindowOperatorQuery windowQuery = new WindowOperatorQuery(
+        new QueryDataSource(groupByQuery),
+        new LegacySegmentSpec(Intervals.ETERNITY),
+        multipleWorkerContext,
+        RowSignature.builder()
+                    .add("d0", ColumnType.STRING)
+                    .add("d1", ColumnType.STRING)
+                    .add("d2", ColumnType.STRING)
+                    .add("w0", ColumnType.LONG)
+                    .add("w1", ColumnType.LONG).build(),
+        ImmutableList.of(
+            new 
NaiveSortOperatorFactory(ImmutableList.of(ColumnWithDirection.ascending("d0"), 
ColumnWithDirection.ascending("d1"), ColumnWithDirection.ascending("d2"))),
+            new NaivePartitioningOperatorFactory(Collections.emptyList()),
+            new WindowOperatorFactory(new WindowRowNumberProcessor("w0")),
+            new 
NaiveSortOperatorFactory(ImmutableList.of(ColumnWithDirection.ascending("d1"), 
ColumnWithDirection.ascending("d0"), ColumnWithDirection.ascending("d2"))),
+            new 
NaivePartitioningOperatorFactory(Collections.singletonList("d1")),
+            new WindowOperatorFactory(new 
WindowFramedAggregateProcessor(WindowFrame.forOrderBy("d0", "d1", "d2"), aggs))
+        ),
+        ImmutableList.of()
+    );
+
+    final ScanQuery scanQuery = Druids.newScanQueryBuilder()
+                                  .dataSource(new QueryDataSource(windowQuery))
+                                  
.intervals(querySegmentSpec(Filtration.eternity()))
+                                  .columns("d0", "d1", "d2", "w0", "w1")
+                                  .orderBy(
+                                      ImmutableList.of(
+                                          new ScanQuery.OrderBy("d0", 
ScanQuery.Order.ASCENDING),
+                                          new ScanQuery.OrderBy("d1", 
ScanQuery.Order.ASCENDING),
+                                          new ScanQuery.OrderBy("d2", 
ScanQuery.Order.ASCENDING)
+                                      )
+                                  )
+                                  .columnTypes(ColumnType.STRING, 
ColumnType.STRING, ColumnType.STRING, ColumnType.LONG, ColumnType.LONG)
+                                  .limit(Long.MAX_VALUE)
+                                  
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
+                                  .context(contextWithRowSignature)
+                                  .build();
+
+    final String sql = "select countryName, cityName, channel, \n"
+                          + "row_number() over (order by countryName, 
cityName, channel) as c1, \n"
+                          + "count(channel) over (partition by cityName order 
by countryName, cityName, channel) as c2\n"
+                          + "from wikipedia\n"
+                          + "where countryName in ('Austria', 'Republic of 
Korea')\n"
+                          + "group by countryName, cityName, channel "
+                          + "order by countryName, cityName, channel";
+
+    final String nullValue = NullHandling.sqlCompatible() ? null : "";
+
+    testSelectQuery()
+        .setSql(sql)
+        .setExpectedMSQSpec(MSQSpec.builder()
+                                   .query(scanQuery)
+                                   .columnMappings(
+                                       new ColumnMappings(ImmutableList.of(
+                                           new ColumnMapping("d0", 
"countryName"),
+                                           new ColumnMapping("d1", "cityName"),
+                                           new ColumnMapping("d2", "channel"),
+                                           new ColumnMapping("w0", "c1"),
+                                           new ColumnMapping("w1", "c2")
+                                       )
+                                       ))
+                                   
.tuningConfig(MSQTuningConfig.defaultConfig())
+                                   .build())
+        .setExpectedRowSignature(rowSignature)
+        .setExpectedResultRows(
+            ImmutableList.<Object[]>of(
+                new Object[]{"Austria", nullValue, "#de.wikipedia", 1L, 1L},
+                new Object[]{"Austria", "Horsching", "#de.wikipedia", 2L, 1L},
+                new Object[]{"Austria", "Vienna", "#de.wikipedia", 3L, 1L},
+                new Object[]{"Austria", "Vienna", "#es.wikipedia", 4L, 2L},
+                new Object[]{"Austria", "Vienna", "#tr.wikipedia", 5L, 3L},
+                new Object[]{"Republic of Korea", nullValue, "#en.wikipedia", 
6L, 2L},
+                new Object[]{"Republic of Korea", nullValue, "#ja.wikipedia", 
7L, 3L},
+                new Object[]{"Republic of Korea", nullValue, "#ko.wikipedia", 
8L, 4L},
+                new Object[]{"Republic of Korea", "Jeonju", "#ko.wikipedia", 
9L, 1L},
+                new Object[]{"Republic of Korea", "Seongnam-si", 
"#ko.wikipedia", 10L, 1L},
+                new Object[]{"Republic of Korea", "Seoul", "#ko.wikipedia", 
11L, 1L},
+                new Object[]{"Republic of Korea", "Suwon-si", "#ko.wikipedia", 
12L, 1L},
+                new Object[]{"Republic of Korea", "Yongsan-dong", 
"#ko.wikipedia", 13L, 1L}
+            )
+        )
+        .setQueryContext(multipleWorkerContext)

Review Comment:
   @adarshsanjeev The counters don't seem to expose the shuffle kind, though. 
Could you elaborate on how asserting on the counters would help us verify that 
MixShuffleSpec is being used here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to