LakshSingla commented on code in PR #14450:
URL: https://github.com/apache/druid/pull/14450#discussion_r1252602160


##########
sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java:
##########
@@ -5652,4 +5652,128 @@ public void testJoinsWithThreeConditions()
         )
     );
   }
+
+  @Test
+  public void testJoinWithInputRefCondition()
+  {
+    cannotVectorize();
+    Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+
+    Query expectedQuery;
+
+    if (!NullHandling.sqlCompatible()) {
+      expectedQuery = Druids.newTimeseriesQueryBuilder()
+                            .dataSource(
+                                join(
+                                    new 
TableDataSource(CalciteTests.DATASOURCE1),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    
.setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    
.setGranularity(Granularities.ALL)
+                                                    .setDataSource(new 
TableDataSource(CalciteTests.DATASOURCE1))
+                                                    
.setVirtualColumns(expressionVirtualColumn(
+                                                        "v0",
+                                                        "1",
+                                                        ColumnType.LONG
+                                                    ))
+                                                    .setDimensions(
+                                                        new 
DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                                        new 
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    )
+                                                    .build()
+                                    ),
+                                    "j0.",
+                                    "(floor(100) == \"j0.d0\")",
+                                    JoinType.LEFT
+                                )
+                            )
+                            .granularity(Granularities.ALL)
+                            .aggregators(aggregators(
+                                new FilteredAggregatorFactory(
+                                    new CountAggregatorFactory("a0"),
+                                    new SelectorDimFilter("j0.d1", null, null)
+                                )
+                            ))
+                            
.context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+                            .intervals(querySegmentSpec(Filtration.eternity()))
+                            .context(context)
+                            .build();
+
+    } else {
+      expectedQuery = Druids.newTimeseriesQueryBuilder()
+                            .dataSource(
+                                join(
+                                    join(
+                                        new TableDataSource("foo"),
+                                        new QueryDataSource(
+                                            Druids.newTimeseriesQueryBuilder()
+                                                  .dataSource("foo")
+                                                  .aggregators(
+                                                      new 
CountAggregatorFactory("a0"),
+                                                      new 
FilteredAggregatorFactory(
+                                                          new 
CountAggregatorFactory("a1"),
+                                                          not(selector("m1", 
null, null)),
+                                                          "a1"
+                                                      )
+                                                  )
+                                                  
.intervals(querySegmentSpec(Filtration.eternity()))
+                                                  .context(context)
+                                                  .build()
+                                        ),
+                                        "j0.",
+                                        "1",
+                                        JoinType.INNER
+                                    ),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    
.setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    
.setGranularity(Granularities.ALL)
+                                                    .setDataSource(new 
TableDataSource(CalciteTests.DATASOURCE1))
+                                                    
.setVirtualColumns(expressionVirtualColumn(
+                                                        "v0",
+                                                        "1",
+                                                        ColumnType.LONG
+                                                    ))
+                                                    .setDimensions(
+                                                        new 
DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                                        new 
DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    )
+                                                    .build()
+                                    ),
+                                    "_j0.",
+                                    "(floor(100) == \"_j0.d0\")",
+                                    JoinType.LEFT
+                                )
+                            )
+                            .granularity(Granularities.ALL)
+                            .aggregators(aggregators(
+                                new FilteredAggregatorFactory(
+                                    new CountAggregatorFactory("a0"),
+                                    or(
+                                        new SelectorDimFilter("j0.a0", "0", 
null),
+                                        and(
+                                            selector("_j0.d1", null, null),
+                                            expressionFilter("(\"j0.a1\" >= 
\"j0.a0\")")
+                                        )
+
+                                    )
+                                )
+                            ))
+                            
.context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+                            .intervals(querySegmentSpec(Filtration.eternity()))
+                            .context(context)
+                            .build();
+
+    }
+
+    testQuery(
+        "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) 
"
+        + "FROM foo",
+        context,

Review Comment:
   This is a test in the sense that this class is subclassed to execute with 
both join algorithms; however, no assertion is made about which algorithm is 
used (since the test stems from the parent test class). 
   I'll add a similar test in the MSQ engine and see if I can add some 
assertions on the algorithm used. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to