adarshsanjeev commented on code in PR #14450:
URL: https://github.com/apache/druid/pull/14450#discussion_r1251905160


##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java:
##########
@@ -198,6 +203,28 @@ public Optional<QueryDefinitionBuilder> 
getSubQueryDefBuilder()
     return Optional.ofNullable(subQueryDefBuilder);
   }
 
+  private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm 
preferredJoinAlgorithm, JoinDataSource joinDataSource)
+  {
+
+    if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) {
+      return JoinAlgorithm.BROADCAST;
+    }
+
+    // preferredJoinAlgorithm would only be sortMerge now
+
+    if 
(isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis()))
 {

Review Comment:
   We might want to update the MSQ docs for the join algorithm to note that the
configured join algorithm is now more of a hint, and the algorithm actually used
might differ from the one requested.



##########
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java:
##########
@@ -198,6 +203,28 @@ public Optional<QueryDefinitionBuilder> 
getSubQueryDefBuilder()
     return Optional.ofNullable(subQueryDefBuilder);
   }
 
+  private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm 
preferredJoinAlgorithm, JoinDataSource joinDataSource)

Review Comment:
   We might want to log the returned join algorithm here. If the query fails
before the full plan is available, we might otherwise not know which join
algorithm was chosen.



##########
sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java:
##########
@@ -5652,4 +5652,128 @@ public void testJoinsWithThreeConditions()
         )
     );
   }
+
+  @Test
+  public void testJoinWithInputRefCondition()
+  {
+    cannotVectorize();
+    Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+
+    Query expectedQuery;
+
+    if (!NullHandling.sqlCompatible()) {
+      expectedQuery = Druids.newTimeseriesQueryBuilder()
+                            .dataSource(
+                                join(
+                                    new 
TableDataSource(CalciteTests.DATASOURCE1),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    
.setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    
.setGranularity(Granularities.ALL)
+                                                    .setDataSource(new 
TableDataSource(CalciteTests.DATASOURCE1))
+                                                    
.setVirtualColumns(expressionVirtualColumn(
+                                                        "v0",
+                                                        "1",
+                                                        ColumnType.LONG
+                                                    ))
+                                                    .setDimensions(
+                                                        new 
DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                                        new 
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    )
+                                                    .build()
+                                    ),
+                                    "j0.",
+                                    "(floor(100) == \"j0.d0\")",
+                                    JoinType.LEFT
+                                )
+                            )
+                            .granularity(Granularities.ALL)
+                            .aggregators(aggregators(
+                                new FilteredAggregatorFactory(
+                                    new CountAggregatorFactory("a0"),
+                                    new SelectorDimFilter("j0.d1", null, null)
+                                )
+                            ))
+                            
.context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+                            .intervals(querySegmentSpec(Filtration.eternity()))
+                            .context(context)
+                            .build();
+
+    } else {
+      expectedQuery = Druids.newTimeseriesQueryBuilder()
+                            .dataSource(
+                                join(
+                                    join(
+                                        new TableDataSource("foo"),
+                                        new QueryDataSource(
+                                            Druids.newTimeseriesQueryBuilder()
+                                                  .dataSource("foo")
+                                                  .aggregators(
+                                                      new 
CountAggregatorFactory("a0"),
+                                                      new 
FilteredAggregatorFactory(
+                                                          new 
CountAggregatorFactory("a1"),
+                                                          not(selector("m1", 
null, null)),
+                                                          "a1"
+                                                      )
+                                                  )
+                                                  
.intervals(querySegmentSpec(Filtration.eternity()))
+                                                  .context(context)
+                                                  .build()
+                                        ),
+                                        "j0.",
+                                        "1",
+                                        JoinType.INNER
+                                    ),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    
.setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    
.setGranularity(Granularities.ALL)
+                                                    .setDataSource(new 
TableDataSource(CalciteTests.DATASOURCE1))
+                                                    
.setVirtualColumns(expressionVirtualColumn(
+                                                        "v0",
+                                                        "1",
+                                                        ColumnType.LONG
+                                                    ))
+                                                    .setDimensions(
+                                                        new 
DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                                        new 
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    )
+                                                    .build()
+                                    ),
+                                    "_j0.",
+                                    "(floor(100) == \"_j0.d0\")",
+                                    JoinType.LEFT
+                                )
+                            )
+                            .granularity(Granularities.ALL)
+                            .aggregators(aggregators(
+                                new FilteredAggregatorFactory(
+                                    new CountAggregatorFactory("a0"),
+                                    or(
+                                        new SelectorDimFilter("j0.a0", "0", 
null),
+                                        and(
+                                            selector("_j0.d1", null, null),
+                                            expressionFilter("(\"j0.a1\" >= 
\"j0.a0\")")
+                                        )
+
+                                    )
+                                )
+                            ))
+                            
.context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+                            .intervals(querySegmentSpec(Filtration.eternity()))
+                            .context(context)
+                            .build();
+
+    }
+
+    testQuery(
+        "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) 
"
+        + "FROM foo",
+        context,

Review Comment:
   Should we add a test where the query context sets the join algorithm to
sort-merge, but, because of the query structure, MSQ uses a broadcast join
instead?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to