This is an automated email from the ASF dual-hosted git repository.

lakshsingla pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ce536355e Fix planning bug while using sort merge frame processor (#14450)
5ce536355e is described below

commit 5ce536355e80cbb72e7319359c2047508d907e7d
Author: Laksh Singla <[email protected]>
AuthorDate: Tue Jul 11 15:28:44 2023 +0530

    Fix planning bug while using sort merge frame processor (#14450)
    
    sqlJoinAlgorithm is now a hint to the planner to execute the join in the specified manner. The planner can decide to ignore the hint if it determines beforehand that the specified algorithm would be detrimental to the performance of the join.
---
 docs/multi-stage-query/reference.md                |  10 +-
 .../apache/druid/msq/querykit/DataSourcePlan.java  |  56 +++++++++-
 .../org/apache/druid/msq/exec/MSQSelectTest.java   | 103 +++++++++++++++++
 .../org/apache/druid/msq/test/MSQTestBase.java     |  16 +++
 .../druid/sql/calcite/rule/DruidJoinRule.java      |  70 +++++++-----
 .../druid/sql/calcite/CalciteJoinQueryTest.java    | 124 +++++++++++++++++++++
 .../druid/sql/calcite/rule/DruidJoinRuleTest.java  |  21 ++--
 7 files changed, 363 insertions(+), 37 deletions(-)

diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md
index 5bbe935f1e..08335ff114 100644
--- a/docs/multi-stage-query/reference.md
+++ b/docs/multi-stage-query/reference.md
@@ -234,7 +234,7 @@ The following table lists the context parameters for the MSQ task engine:
 | `maxNumTasks` | SELECT, INSERT, REPLACE<br /><br />The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.<br /><br />May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority. [...]
 | `taskAssignment` | SELECT, INSERT, REPLACE<br /><br />Determines how many tasks to use. Possible values include: <ul><li>`max`: Uses as many tasks as possible, up to `maxNumTasks`.</li><li>`auto`: When file sizes can be determined through directory listing (for example: local files, S3, GCS, HDFS) uses as few tasks as possible without exceeding 512 MiB or 10,000 files per task, unless exceeding these limits is necessary to stay within `maxNumTasks`. When calculating the size of files, [...]
 | `finalizeAggregations` | SELECT, INSERT, REPLACE<br /><br />Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggr [...]
-| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE<br /><br />Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. [...]
+| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE<br /><br />Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. This is a hint to the MSQ engine and the actual joins in the query may proceed in a different way than specified. See [Joins](#joins) for more details. [...]
 | `rowsInMemory` | INSERT or REPLACE<br /><br />Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. [...]
 | `segmentSortOrder` | INSERT or REPLACE<br /><br />Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.<br /><br />You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, [...]
 | `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1. [...]
@@ -253,6 +253,12 @@ Joins in multi-stage queries use one of two algorithms based on what you set the
 If you omit this context parameter, the MSQ task engine uses broadcast since it's the default join algorithm. The context parameter applies to the entire SQL statement, so you can't mix different
 join algorithms in the same query.
 
+`sqlJoinAlgorithm` is a hint to the planner to execute the join in the specified manner. The planner can ignore the hint
+if it determines beforehand that the specified algorithm would be detrimental to the performance of the join. This logic
+is very limited as of now, so the `sqlJoinAlgorithm` you set is respected in most cases and you should set it
+appropriately. See the advantages and drawbacks of the [broadcast](#broadcast) and [sort-merge](#sort-merge) joins to
+determine which join to use.
+
 ### Broadcast
 
 The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how
@@ -439,7 +445,7 @@ The following table describes error codes you may encounter in the `multiStageQu
 | <a name="error_TooManyInputFiles">`TooManyInputFiles`</a> | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).<br /><br />If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.<br /><br />`maxInputFiles`: The maximum number of input files/segments per worker per stage.<br  [...]
 | <a name="error_TooManyPartitions">`TooManyPartitions`</a> | Exceeded the maximum number of partitions for a stage (25,000 partitions).<br /><br />This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded |
 | <a name="error_TooManyClusteredByColumns">`TooManyClusteredByColumns`</a>  | Exceeded the maximum number of clustering columns for a stage (1,500 columns).<br /><br />This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded.`stage`: The stage number exceeding the limit<br /><br /> |
-| <a name="error_TooManyRowsWithSameKey">`TooManyRowsWithSameKey`</a> | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.<br /><br />`numBytes`: Number of bytes buffered, which may include other keys.<br /><br />`maxBytes`: Maximum number of bytes buffered. |
+| <a name="error_TooManyRowsWithSameKey">`TooManyRowsWithSameKey`</a> | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when the join is executed using the sort-merge join algorithm. | `key`: The key that had a large number of rows.<br /><br />`numBytes`: Number of bytes buffered, which may include other keys.<br /><br />`maxBytes`: Maximum number of bytes buffered. |
 | <a name="error_TooManyColumns">`TooManyColumns`</a> | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded. |
 | <a name="error_TooManyWarnings">`TooManyWarnings`</a> | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit. <br /><br />`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. |
 | <a name="error_TooManyWorkers">`TooManyWorkers`</a> | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously. <br /><br />`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1, [...]
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
index 95f5eae7bb..477c3e0e19 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
@@ -30,6 +30,7 @@ import org.apache.druid.frame.key.KeyColumn;
 import org.apache.druid.java.util.common.IAE;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.UOE;
+import org.apache.druid.java.util.common.logger.Logger;
 import org.apache.druid.msq.input.InputSpec;
 import org.apache.druid.msq.input.external.ExternalInputSpec;
 import org.apache.druid.msq.input.inline.InlineInputSpec;
@@ -56,6 +57,7 @@ import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
 import org.apache.druid.query.spec.QuerySegmentSpec;
 import org.apache.druid.segment.column.ColumnHolder;
 import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.join.JoinConditionAnalysis;
 import org.apache.druid.sql.calcite.external.ExternalDataSource;
 import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
 import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
@@ -79,6 +81,8 @@ public class DataSourcePlan
    */
   private static final Map<String, Object> CONTEXT_MAP_NO_SEGMENT_GRANULARITY = new HashMap<>();
 
+  private static final Logger log = new Logger(DataSourcePlan.class);
+
   static {
     CONTEXT_MAP_NO_SEGMENT_GRANULARITY.put(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, null);
   }
@@ -144,9 +148,13 @@ public class DataSourcePlan
           broadcast
       );
     } else if (dataSource instanceof JoinDataSource) {
-      final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+      final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+      final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm(
+          preferredJoinAlgorithm,
+          ((JoinDataSource) dataSource)
+      );
 
-      switch (joinAlgorithm) {
+      switch (deducedJoinAlgorithm) {
         case BROADCAST:
           return forBroadcastHashJoin(
               queryKit,
@@ -171,7 +179,7 @@ public class DataSourcePlan
           );
 
         default:
-          throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm);
+          throw new UOE("Cannot handle join algorithm [%s]", deducedJoinAlgorithm);
       }
     } else {
       throw new UOE("Cannot handle dataSource [%s]", dataSource);
@@ -198,6 +206,48 @@ public class DataSourcePlan
     return Optional.ofNullable(subQueryDefBuilder);
   }
 
+  /**
+   * Contains the logic that deduces the join algorithm to be used. Ideally, this should happen while planning the
+   * native query; however, the resources and structure to do so were not in place when this function was added.
+   * Therefore, this is done while planning the MSQ query.
+   * It takes into account the algorithm specified by "sqlJoinAlgorithm" in the query context and the join condition
+   * that is present in the query.
+   */
+  private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource)
+  {
+    JoinAlgorithm deducedJoinAlgorithm;
+    if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) {
+      deducedJoinAlgorithm = JoinAlgorithm.BROADCAST;
+    } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) {
+      deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE;
+    } else {
+      deducedJoinAlgorithm = JoinAlgorithm.BROADCAST;
+    }
+
+    if (deducedJoinAlgorithm != preferredJoinAlgorithm) {
+      log.debug(
+          "User wanted to plan join [%s] as [%s], however the join will be executed as [%s]",
+          joinDataSource,
+          preferredJoinAlgorithm.toString(),
+          deducedJoinAlgorithm.toString()
+      );
+    }
+
+    return deducedJoinAlgorithm;
+  }
+
+  /**
+   * Checks if the join condition on two tables "table1" and "table2" is of the form
+   * table1.columnA = table2.columnA && table1.columnB = table2.columnB && ....
+   * The sortMerge algorithm is only used for join conditions of this form.
+   */
+  private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis)
+  {
+    return joinConditionAnalysis.getEquiConditions()
+                                .stream()
+                                .allMatch(equality -> equality.getLeftExpr().isIdentifier());
+  }
+
   /**
    * Whether this datasource must be processed by a single worker. True if, and only if, all inputs are broadcast.
    */
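The rule added in `deduceJoinAlgorithm` can be modeled outside Druid to see which combinations keep the hint. The sketch below is a simplified stand-in, not Druid's API: the hypothetical `Equality` class replaces the equi-conditions of `JoinConditionAnalysis`, and its boolean flag stands in for `getLeftExpr().isIdentifier()`. It applies the same decision: `broadcast` is always honored, while `sortMerge` is honored only when every equi-condition has a plain column on its left-hand side.

```java
import java.util.List;

public class JoinAlgorithmDeduction {
  enum JoinAlgorithm { BROADCAST, SORT_MERGE }

  // Hypothetical stand-in for one equi-condition "leftExpr = rightColumn".
  static final class Equality {
    final String leftExpr;
    final boolean leftIsIdentifier; // true when leftExpr is a plain column reference
    final String rightColumn;

    Equality(String leftExpr, boolean leftIsIdentifier, String rightColumn) {
      this.leftExpr = leftExpr;
      this.leftIsIdentifier = leftIsIdentifier;
      this.rightColumn = rightColumn;
    }
  }

  // Same decision rule as deduceJoinAlgorithm in the diff.
  static JoinAlgorithm deduce(JoinAlgorithm preferred, List<Equality> equiConditions) {
    if (preferred == JoinAlgorithm.BROADCAST) {
      return JoinAlgorithm.BROADCAST; // broadcast is always honored
    }
    // allMatch is true for an empty list, so "no equi-conditions" keeps SORT_MERGE.
    boolean allLeftColumns = equiConditions.stream().allMatch(e -> e.leftIsIdentifier);
    return allLeftColumns ? JoinAlgorithm.SORT_MERGE : JoinAlgorithm.BROADCAST;
  }

  public static void main(String[] args) {
    // t1.k = t2.k: plain columns on both sides, so the sortMerge hint is kept.
    System.out.println(deduce(JoinAlgorithm.SORT_MERGE, List.of(new Equality("k", true, "k"))));
    // floor(100) = j0.d0: expression on the left, so the planner falls back to broadcast.
    System.out.println(deduce(JoinAlgorithm.SORT_MERGE, List.of(new Equality("floor(100)", false, "j0.d0"))));
    // No equi-conditions at all.
    System.out.println(deduce(JoinAlgorithm.SORT_MERGE, List.of()));
  }
}
```

The second call mirrors the `MSQSelectTest` case in this commit, whose join condition `floor(100) == "j0.d0"` has an expression on the left. The third call shows a consequence of using `Stream.allMatch`: it returns `true` for an empty stream, so under this rule a join with no equi-conditions keeps the `sortMerge` hint.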
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
index d751946f24..0d4b3aff2f 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
@@ -41,12 +41,14 @@ import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
 import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.indexing.report.MSQResultsReport;
+import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
 import org.apache.druid.msq.test.CounterSnapshotMatcher;
 import org.apache.druid.msq.test.MSQTestBase;
 import org.apache.druid.msq.test.MSQTestFileUtils;
 import org.apache.druid.msq.util.MultiStageQueryContext;
 import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.LookupDataSource;
+import org.apache.druid.query.Query;
 import org.apache.druid.query.QueryDataSource;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
@@ -93,6 +95,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -2013,6 +2016,106 @@ public class MSQSelectTest extends MSQTestBase
         .verifyResults();
   }
 
+  @Test
+  public void testJoinUsesDifferentAlgorithm()
+  {
+
+    // This test asserts that the join algorithm used is different from the one supplied. In sqlCompatible() mode
+    // the query gets planned differently, therefore the sortMerge processor is used. Instead of adding separate
+    // handling, a similar test has been described in CalciteJoinQueryMSQTest, so we don't repeat that
+    // here and skip this test in sqlCompatible() mode.
+    if (NullHandling.sqlCompatible()) {
+      return;
+    }
+
+    RowSignature rowSignature = RowSignature.builder().add("cnt", ColumnType.LONG).build();
+
+    Map<String, Object> queryContext = new HashMap<>(context);
+    queryContext.put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.toString());
+
+    Query<?> expectedQuery;
+
+    expectedQuery = GroupByQuery
+        .builder()
+        .setDataSource(
+            join(
+                new QueryDataSource(
+                    newScanQueryBuilder()
+                        .dataSource("foo")
+                        .virtualColumns(expressionVirtualColumn("v0", "0", ColumnType.LONG))
+                        .columns("v0")
+                        .context(defaultScanQueryContext(
+                            queryContext,
+                            RowSignature.builder().add("v0", ColumnType.LONG).build()
+                        ))
+                        .intervals(querySegmentSpec(Intervals.ETERNITY))
+                        .build()
+                ),
+                new QueryDataSource(
+                    GroupByQuery.builder()
+                                .setDataSource("foo")
+                                .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
+                                .setDimensions(
+                                    new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                    new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                )
+                                .setContext(queryContext)
+                                .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
+                                .setGranularity(Granularities.ALL)
+                                .build()
+
+                ),
+                "j0.",
+                "(floor(100) == \"j0.d0\")",
+                JoinType.LEFT
+            )
+        )
+        .setAggregatorSpecs(
+            new FilteredAggregatorFactory(
+                new CountAggregatorFactory("a0"),
+                new SelectorDimFilter("j0.d1", null, null),
+                "a0"
+            )
+        )
+        .setContext(queryContext)
+        .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
+        .setGranularity(Granularities.ALL)
+        .build();
+
+    testSelectQuery()
+        .setSql(
+            "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) AS cnt "
+            + "FROM foo"
+        )
+        .setExpectedRowSignature(rowSignature)
+        .setExpectedMSQSpec(
+            MSQSpec
+                .builder()
+                .query(expectedQuery)
+                .columnMappings(new ColumnMappings(
+                    ImmutableList.of(
+                        new ColumnMapping("a0", "cnt")
+                    )
+                ))
+                .destination(isDurableStorageDestination()
+                             ? DurableStorageMSQDestination.INSTANCE
+                             : TaskReportMSQDestination.INSTANCE)
+                .tuningConfig(MSQTuningConfig.defaultConfig())
+                .build())
+        .setQueryContext(queryContext)
+        .addAdhocReportAssertions(
+            msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch(
+                stage -> stage.getStageDefinition()
+                              .getProcessorFactory()
+                              .getClass()
+                              .equals(SortMergeJoinFrameProcessorFactory.class)
+            ),
+            "assert the query didn't use sort merge"
+        )
+        .setExpectedResultRows(ImmutableList.of(new Object[]{6L}))
+        .verifyResults();
+  }
+
   @Nonnull
   private List<Object[]> expectedMultiValueFooRowsGroup()
   {
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index 68965f40bf..4e00fd657a 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -202,6 +202,7 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
+import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
@@ -807,6 +808,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
     protected Set<Interval> expectedTombstoneIntervals = null;
     protected List<Object[]> expectedResultRows = null;
     protected Matcher<Throwable> expectedValidationErrorMatcher = null;
+    protected List<Pair<Predicate<MSQTaskReportPayload>, String>> adhocReportAssertionAndReasons = new ArrayList<>();
     protected Matcher<Throwable> expectedExecutionErrorMatcher = null;
     protected MSQFault expectedMSQFault = null;
     protected Class<? extends MSQFault> expectedMSQFaultClass = null;
@@ -868,6 +870,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
       return asBuilder();
     }
 
+    public Builder addAdhocReportAssertions(Predicate<MSQTaskReportPayload> predicate, String reason)
+    {
+      this.adhocReportAssertionAndReasons.add(Pair.of(predicate, reason));
+      return asBuilder();
+    }
+
     public Builder setExpectedValidationErrorMatcher(Matcher<Throwable> expectedValidationErrorMatcher)
     {
       this.expectedValidationErrorMatcher = expectedValidationErrorMatcher;
@@ -1230,6 +1238,11 @@ public class MSQTestBase extends BaseCalciteQueryTest
           }
           Assert.assertEquals(expectedTombstoneSegmentIds, tombstoneSegmentIds);
         }
+
+        for (Pair<Predicate<MSQTaskReportPayload>, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) {
+          Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(reportPayload));
+        }
+
         // assert results
         assertResultsEquals(sql, expectedResultRows, transformedOutputRows);
       }
@@ -1340,6 +1353,9 @@ public class MSQTestBase extends BaseCalciteQueryTest
           log.info("found row signature %s", payload.getResults().getSignature());
           log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n")));
 
+          for (Pair<Predicate<MSQTaskReportPayload>, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) {
+            Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(payload));
+          }
 
           log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec));
           return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows));
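The new `addAdhocReportAssertions` hook pairs a predicate over the task report with a failure message and checks every pair once the query has run. A minimal, self-contained sketch of that pattern follows. `AdhocAssertions` and the list-of-processor-names "report" are hypothetical stand-ins for the MSQ test builder and `MSQTaskReportPayload`, and the check throws `AssertionError` directly instead of calling JUnit's `Assert.assertTrue`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

public class AdhocAssertions<R> {
  // A predicate over the report plus the message reported when it does not hold.
  private static final class Check<T> {
    final Predicate<T> predicate;
    final String reason;

    Check(Predicate<T> predicate, String reason) {
      this.predicate = predicate;
      this.reason = reason;
    }
  }

  private final List<Check<R>> checks = new ArrayList<>();

  // Registers a check; returns this so calls can be chained like a test builder.
  public AdhocAssertions<R> add(Predicate<R> predicate, String reason) {
    checks.add(new Check<>(predicate, reason));
    return this;
  }

  // Evaluates every registered check; the first one that fails throws with its reason.
  public void verify(R report) {
    for (Check<R> check : checks) {
      if (!check.predicate.test(report)) {
        throw new AssertionError(check.reason);
      }
    }
  }

  public static void main(String[] args) {
    // Hypothetical report: one processor name per stage.
    List<String> stageProcessors = List.of("ScanProcessor", "GroupByProcessor");
    new AdhocAssertions<List<String>>()
        .add(stages -> stages.stream().noneMatch("SortMergeJoinFrameProcessorFactory"::equals),
             "assert the query didn't use sort merge")
        .verify(stageProcessors);
    System.out.println("all ad hoc assertions passed");
  }
}
```

Keeping the reason next to the predicate means a failing check explains why it exists, which is the role the `String` half of each `Pair` plays in `MSQTestBase`.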
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
index c6ba499921..7bbcc44799 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
@@ -39,6 +39,7 @@ import org.apache.calcite.rex.RexSlot;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.druid.java.util.common.Pair;
@@ -92,7 +93,7 @@ public class DruidJoinRule extends RelOptRule
     // 1) Can handle the join condition as a native join.
     // 2) Left has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
     // 3) Right has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
-    return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right)
+    return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right, join.getCluster().getRexBuilder())
            && left.getPartialDruidQuery() != null
            && right.getPartialDruidQuery() != null;
   }
@@ -116,7 +117,8 @@ public class DruidJoinRule extends RelOptRule
     ConditionAnalysis conditionAnalysis = analyzeCondition(
         join.getCondition(),
         join.getLeft().getRowType(),
-        right
+        right,
+        rexBuilder
     ).get();
     final boolean isLeftDirectAccessPossible = enableLeftScanDirect && (left instanceof DruidQueryRel);
 
@@ -223,9 +225,9 @@ public class DruidJoinRule extends RelOptRule
    * Returns whether {@link #analyzeCondition} would return something.
    */
   @VisibleForTesting
-  boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right)
+  boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right, final RexBuilder rexBuilder)
   {
-    return analyzeCondition(condition, leftRowType, right).isPresent();
+    return analyzeCondition(condition, leftRowType, right, rexBuilder).isPresent();
   }
 
   /**
@@ -235,7 +237,8 @@ public class DruidJoinRule extends RelOptRule
   private Optional<ConditionAnalysis> analyzeCondition(
       final RexNode condition,
       final RelDataType leftRowType,
-      final DruidRel<?> right
+      final DruidRel<?> right,
+      final RexBuilder rexBuilder
   )
   {
     final List<RexNode> subConditions = decomposeAnd(condition);
@@ -266,8 +269,29 @@ public class DruidJoinRule extends RelOptRule
         continue;
       }
 
-      if (!subCondition.isA(SqlKind.EQUALS)) {
-        // If it's not EQUALS, it's not supported.
+      RexNode firstOperand;
+      RexNode secondOperand;
+
+      if (subCondition.isA(SqlKind.INPUT_REF)) {
+        firstOperand = rexBuilder.makeLiteral(true);
+        secondOperand = subCondition;
+
+        if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) {
+          plannerContext.setPlanningError(
+              "SQL requires a join with '%s' condition where the column is of the type %s, that is not supported",
+              subCondition.getKind(),
+              secondOperand.getType().getSqlTypeName()
+          );
+          return Optional.empty();
+
+        }
+      } else if (subCondition.isA(SqlKind.EQUALS)) {
+        final List<RexNode> operands = ((RexCall) subCondition).getOperands();
+        Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size());
+        firstOperand = operands.get(0);
+        secondOperand = operands.get(1);
+      } else {
+        // If it's not EQUALS or a BOOLEAN input ref, it's not supported.
         plannerContext.setPlanningError(
             "SQL requires a join with '%s' condition that is not supported.",
             subCondition.getKind()
@@ -275,16 +299,13 @@ public class DruidJoinRule extends RelOptRule
         return Optional.empty();
       }
 
-      final List<RexNode> operands = ((RexCall) subCondition).getOperands();
-      Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size());
-
-      if (isLeftExpression(operands.get(0), numLeftFields) && isRightInputRef(operands.get(1), numLeftFields)) {
-        equalitySubConditions.add(Pair.of(operands.get(0), (RexInputRef) operands.get(1)));
-        rightColumns.add((RexInputRef) operands.get(1));
-      } else if (isRightInputRef(operands.get(0), numLeftFields)
-                 && isLeftExpression(operands.get(1), numLeftFields)) {
-        equalitySubConditions.add(Pair.of(operands.get(1), (RexInputRef) operands.get(0)));
-        rightColumns.add((RexInputRef) operands.get(0));
+      if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) {
+        equalitySubConditions.add(Pair.of(firstOperand, (RexInputRef) secondOperand));
+        rightColumns.add((RexInputRef) secondOperand);
+      } else if (isRightInputRef(firstOperand, numLeftFields)
+                 && isLeftExpression(secondOperand, numLeftFields)) {
+        equalitySubConditions.add(Pair.of(secondOperand, (RexInputRef) firstOperand));
+        rightColumns.add((RexInputRef) firstOperand);
       } else {
         // Cannot handle this condition.
         plannerContext.setPlanningError("SQL is resulting in a join that has unsupported operand types.");
@@ -310,7 +331,12 @@ public class DruidJoinRule extends RelOptRule
       }
     }
 
-    return Optional.of(new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions));
+    return Optional.of(
+        new ConditionAnalysis(
+            numLeftFields,
+            equalitySubConditions,
+            literalSubConditions
+        ));
   }
 
   @VisibleForTesting
@@ -341,13 +367,6 @@ public class DruidJoinRule extends RelOptRule
 
   private boolean isLeftExpression(final RexNode rexNode, final int numLeftFields)
   {
-    if (!plannerContext.getJoinAlgorithm().canHandleLeftExpressions()) {
-      // Must be INPUT_REF.
-      if (!rexNode.isA(SqlKind.INPUT_REF)) {
-        return false;
-      }
-    }
-
     return ImmutableBitSet.range(numLeftFields).contains(RelOptUtil.InputFinder.bits(rexNode));
   }
 
@@ -375,6 +394,7 @@ public class DruidJoinRule extends RelOptRule
      */
     private final List<RexLiteral> literalSubConditions;
 
+
     ConditionAnalysis(
         int numLeftFields,
         List<Pair<RexNode, RexInputRef>> equalitySubConditions,
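After this change, `analyzeCondition` accepts a bare column as a join sub-condition by rewriting it as `TRUE = column`, provided the column is boolean, and then normalizes every equality so that the left-side expression comes first and the right-side column second. With the `canHandleLeftExpressions()` check removed, this no longer depends on the requested join algorithm. The sketch below models that normalization with hypothetical, heavily simplified operands instead of Calcite `RexNode`s: field indexes below `numLeftFields` belong to the left input, and a literal counts as a left-side expression because it references no input fields, which matches how `isLeftExpression` tests input bits.

```java
import java.util.Optional;

public class JoinConditionNormalizer {
  // Hypothetical simplified operand: a literal, or a reference to input field `index`.
  // Fields [0, numLeftFields) come from the left input; the remaining fields from the right.
  static final class Operand {
    final Integer index;   // null for literals
    final String literal;  // null for field references
    final boolean isBoolean;

    private Operand(Integer index, String literal, boolean isBoolean) {
      this.index = index;
      this.literal = literal;
      this.isBoolean = isBoolean;
    }

    static Operand ref(int index, boolean isBoolean) {
      return new Operand(index, null, isBoolean);
    }

    static Operand trueLiteral() {
      return new Operand(null, "true", true);
    }

    // A literal references no input fields, so it counts as a left-side expression.
    boolean isLeftExpression(int numLeftFields) {
      return index == null || index < numLeftFields;
    }

    boolean isRightInputRef(int numLeftFields) {
      return index != null && index >= numLeftFields;
    }

    @Override
    public String toString() {
      return index == null ? literal : "$" + index;
    }
  }

  // Orders an equality as "leftExpression = rightColumn", or returns empty if neither order fits.
  static Optional<String> normalize(Operand first, Operand second, int numLeftFields) {
    if (first.isLeftExpression(numLeftFields) && second.isRightInputRef(numLeftFields)) {
      return Optional.of(first + " = " + second);
    }
    if (first.isRightInputRef(numLeftFields) && second.isLeftExpression(numLeftFields)) {
      return Optional.of(second + " = " + first);
    }
    return Optional.empty();
  }

  // A bare column used as a sub-condition is treated as "TRUE = column" if it is boolean.
  static Optional<String> normalizeBareColumn(Operand column, int numLeftFields) {
    if (!column.isBoolean) {
      return Optional.empty();
    }
    return normalize(Operand.trueLiteral(), column, numLeftFields);
  }

  public static void main(String[] args) {
    int numLeftFields = 2; // $0 and $1 are left fields; $2 and $3 are right fields
    System.out.println(normalize(Operand.ref(0, false), Operand.ref(2, false), numLeftFields)); // Optional[$0 = $2]
    System.out.println(normalize(Operand.ref(3, false), Operand.ref(1, false), numLeftFields)); // Optional[$1 = $3]
    System.out.println(normalizeBareColumn(Operand.ref(3, true), numLeftFields));  // Optional[true = $3]
    System.out.println(normalizeBareColumn(Operand.ref(3, false), numLeftFields)); // Optional.empty
  }
}
```

Routing the bare-column case through the same `normalize` step keeps a single code path for ordering operands, which is the same design the diff uses by filling `firstOperand` and `secondOperand` before one shared check.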
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
index db631fe674..c4ff4a17a3 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
@@ -5654,4 +5654,128 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
         )
     );
   }
+
+  @Test
+  public void testJoinWithInputRefCondition()
+  {
+    cannotVectorize();
+    Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+
+    Query<?> expectedQuery;
+
+    if (!NullHandling.sqlCompatible()) {
+      expectedQuery = Druids.newTimeseriesQueryBuilder()
+                            .dataSource(
+                                join(
+                                    new 
TableDataSource(CalciteTests.DATASOURCE1),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    .setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    .setGranularity(Granularities.ALL)
+                                                    .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1))
+                                                    .setVirtualColumns(expressionVirtualColumn(
+                                                        "v0",
+                                                        "1",
+                                                        ColumnType.LONG
+                                                    ))
+                                                    .setDimensions(
+                                                        new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                                        new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    )
+                                                    .build()
+                                    ),
+                                    "j0.",
+                                    "(floor(100) == \"j0.d0\")",
+                                    JoinType.LEFT
+                                )
+                            )
+                            .granularity(Granularities.ALL)
+                            .aggregators(aggregators(
+                                new FilteredAggregatorFactory(
+                                    new CountAggregatorFactory("a0"),
+                                    new SelectorDimFilter("j0.d1", null, null)
+                                )
+                            ))
+                            .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+                            .intervals(querySegmentSpec(Filtration.eternity()))
+                            .context(context)
+                            .build();
+
+    } else {
+      expectedQuery = Druids.newTimeseriesQueryBuilder()
+                            .dataSource(
+                                join(
+                                    join(
+                                        new TableDataSource("foo"),
+                                        new QueryDataSource(
+                                            Druids.newTimeseriesQueryBuilder()
+                                                  .dataSource("foo")
+                                                  .aggregators(
+                                                      new CountAggregatorFactory("a0"),
+                                                      new FilteredAggregatorFactory(
+                                                          new CountAggregatorFactory("a1"),
+                                                          not(selector("m1", null, null)),
+                                                          "a1"
+                                                      )
+                                                  )
+                                                  .intervals(querySegmentSpec(Filtration.eternity()))
+                                                  .context(context)
+                                                  .build()
+                                        ),
+                                        "j0.",
+                                        "1",
+                                        JoinType.INNER
+                                    ),
+                                    new QueryDataSource(
+                                        GroupByQuery.builder()
+                                                    .setInterval(querySegmentSpec(Filtration.eternity()))
+                                                    .setGranularity(Granularities.ALL)
+                                                    .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1))
+                                                    .setVirtualColumns(expressionVirtualColumn(
+                                                        "v0",
+                                                        "1",
+                                                        ColumnType.LONG
+                                                    ))
+                                                    .setDimensions(
+                                                        new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+                                                        new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+                                                    )
+                                                    .build()
+                                    ),
+                                    "_j0.",
+                                    "(floor(100) == \"_j0.d0\")",
+                                    JoinType.LEFT
+                                )
+                            )
+                            .granularity(Granularities.ALL)
+                            .aggregators(aggregators(
+                                new FilteredAggregatorFactory(
+                                    new CountAggregatorFactory("a0"),
+                                    or(
+                                        new SelectorDimFilter("j0.a0", "0", null),
+                                        and(
+                                            selector("_j0.d1", null, null),
+                                            expressionFilter("(\"j0.a1\" >= \"j0.a0\")")
+                                        )
+
+                                    )
+                                )
+                            ))
+                            .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+                            .intervals(querySegmentSpec(Filtration.eternity()))
+                            .context(context)
+                            .build();
+
+    }
+
+    testQuery(
+        "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) "
+        + "FROM foo",
+        context,
+        ImmutableList.of(expectedQuery),
+        ImmutableList.of(
+            new Object[]{6L}
+        )
+    );
+  }
 }
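In SQL-compatible null mode, the expected plan above answers `FLOOR(100) NOT IN (SELECT m1 FROM foo)` with two joins. First, an INNER join on `1` attaches a timeseries that computes `a0` (row count) and `a1` (count of non-null `m1`). Second, a LEFT join on the grouped `m1` values follows, where a non-null `d1` marks a match. The filter `a0 = 0 OR (d1 IS NULL AND a1 >= a0)` then stands in for SQL's three-valued `NOT IN`. The Java sketch below checks that equivalence for `COUNT(*) FILTER`, which counts a row only when the predicate is TRUE. The class, the method names, and the six `m1` values are hypothetical illustrations, not the test fixture's data.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

// Hypothetical model of the NOT IN rewrite; names and data are illustrative only.
final class NotInRewrite {
  // Direct SQL semantics of `probe NOT IN (values)`: FALSE on a match,
  // UNKNOWN (null) if there is no match but the list holds a NULL, TRUE otherwise.
  static Boolean notInDirect(double probe, List<Double> values) {
    boolean sawNull = false;
    for (Double v : values) {
      if (v == null) {
        sawNull = true;
      } else if (v == probe) {
        return Boolean.FALSE;
      }
    }
    return sawNull ? null : Boolean.TRUE;
  }

  // Count-based form mirroring the expected filter: a0 = COUNT(*),
  // a1 = COUNT of non-null values, matched = the LEFT join found a row.
  static boolean notInRewritten(double probe, List<Double> values) {
    long a0 = values.size();
    long a1 = values.stream().filter(Objects::nonNull).count();
    boolean matched = values.stream().anyMatch(v -> v != null && v == probe);
    return a0 == 0 || (!matched && a1 >= a0);
  }

  public static void main(String[] args) {
    // Hypothetical m1 column: six non-null values, none equal to 100.
    List<Double> m1 = Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0);
    double probe = Math.floor(100);
    // COUNT(*) FILTER (WHERE ...) over the six outer rows; the predicate is
    // identical for every row because the probe is a constant.
    long count = m1.stream().filter(row -> notInRewritten(probe, m1)).count();
    System.out.println(count); // 6

    List<Double> withNull = Arrays.asList(1.0, null);
    // Both forms exclude the row: UNKNOWN directly, false after the rewrite.
    System.out.println(notInDirect(probe, withNull));    // null
    System.out.println(notInRewritten(probe, withNull)); // false
  }
}
```

The rewrite collapses UNKNOWN to false, which is safe here because a filtered aggregate treats UNKNOWN and FALSE the same way.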
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
index 41c6895dff..e531580162 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
@@ -84,7 +84,8 @@ public class DruidJoinRuleTest
                 rexBuilder.makeInputRef(joinType, 1)
             ),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }
@@ -104,7 +105,8 @@ public class DruidJoinRuleTest
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
             ),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }
@@ -124,7 +126,8 @@ public class DruidJoinRuleTest
                 )
             ),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }
@@ -140,7 +143,8 @@ public class DruidJoinRuleTest
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0)
             ),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }
@@ -156,7 +160,8 @@ public class DruidJoinRuleTest
                 rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
             ),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }
@@ -168,7 +173,8 @@ public class DruidJoinRuleTest
         druidJoinRule.canHandleCondition(
             rexBuilder.makeLiteral(true),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }
@@ -180,7 +186,8 @@ public class DruidJoinRuleTest
         druidJoinRule.canHandleCondition(
             rexBuilder.makeLiteral(false),
             leftType,
-            null
+            null,
+            rexBuilder
         )
     );
   }

