This is an automated email from the ASF dual-hosted git repository.
lakshsingla pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git
The following commit(s) were added to refs/heads/master by this push:
new 5ce536355e Fix planning bug while using sort merge frame processor (#14450)
5ce536355e is described below
commit 5ce536355e80cbb72e7319359c2047508d907e7d
Author: Laksh Singla <[email protected]>
AuthorDate: Tue Jul 11 15:28:44 2023 +0530
Fix planning bug while using sort merge frame processor (#14450)
sqlJoinAlgorithm is now a hint to the planner to execute the join in the
specified manner. The planner can decide to ignore the hint if it deduces
beforehand that the specified algorithm would be detrimental to the
performance of the join.
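The following is a minimal, self-contained sketch of the deduction that `deduceJoinAlgorithm` in `DataSourcePlan.java` performs in the diff below. It is not Druid code. `JoinAlgorithmDeductionSketch`, the `EquiCondition` record, and its `leftIsColumn` flag are hypothetical stand-ins for `JoinConditionAnalysis` and `Equality.getLeftExpr().isIdentifier()`, and the sketch assumes Java 16+ for `record`. A `broadcast` preference is always kept. A `sortMerge` preference is kept only when every equality compares a left-hand column with a right-hand column. Otherwise, the plan falls back to broadcast.

```java
import java.util.List;

/**
 * Hypothetical, simplified model of the join-algorithm deduction added in this commit.
 * JoinAlgorithm and EquiCondition are illustrative stand-ins, not Druid's classes.
 */
public class JoinAlgorithmDeductionSketch
{
  enum JoinAlgorithm { BROADCAST, SORT_MERGE }

  /** One "leftExpr = rightColumn" equality; leftIsColumn is true when leftExpr is a plain column reference. */
  record EquiCondition(String leftExpr, boolean leftIsColumn, String rightColumn) {}

  static JoinAlgorithm deduce(JoinAlgorithm preferred, List<EquiCondition> equiConditions)
  {
    if (preferred == JoinAlgorithm.BROADCAST) {
      // A broadcast preference is always honored.
      return JoinAlgorithm.BROADCAST;
    }
    // Keep sort-merge only if every equality has a column (not an expression) on the left side.
    boolean allColumnEqualities = equiConditions.stream().allMatch(EquiCondition::leftIsColumn);
    return allColumnEqualities ? JoinAlgorithm.SORT_MERGE : JoinAlgorithm.BROADCAST;
  }

  public static void main(String[] args)
  {
    List<EquiCondition> columnEquality = List.of(new EquiCondition("dim1", true, "j0.k"));
    // Mirrors the test query's condition (floor(100) == "j0.d0"): the left side is an expression.
    List<EquiCondition> expressionEquality = List.of(new EquiCondition("floor(100)", false, "j0.d0"));

    System.out.println(deduce(JoinAlgorithm.SORT_MERGE, columnEquality));     // SORT_MERGE
    System.out.println(deduce(JoinAlgorithm.SORT_MERGE, expressionEquality)); // BROADCAST
    System.out.println(deduce(JoinAlgorithm.BROADCAST, columnEquality));      // BROADCAST
  }
}
```

Because the check relies on `allMatch`, a condition with no equalities at all also keeps `sortMerge`, which matches the `allMatch` call in `isConditionEqualityOnLeftAndRightColumns`.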
---
docs/multi-stage-query/reference.md | 10 +-
.../apache/druid/msq/querykit/DataSourcePlan.java | 56 +++++++++-
.../org/apache/druid/msq/exec/MSQSelectTest.java | 103 +++++++++++++++++
.../org/apache/druid/msq/test/MSQTestBase.java | 16 +++
.../druid/sql/calcite/rule/DruidJoinRule.java | 70 +++++++-----
.../druid/sql/calcite/CalciteJoinQueryTest.java | 124 +++++++++++++++++++++
.../druid/sql/calcite/rule/DruidJoinRuleTest.java | 21 ++--
7 files changed, 363 insertions(+), 37 deletions(-)
diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md
index 5bbe935f1e..08335ff114 100644
--- a/docs/multi-stage-query/reference.md
+++ b/docs/multi-stage-query/reference.md
@@ -234,7 +234,7 @@ The following table lists the context parameters for the MSQ task engine:
| `maxNumTasks` | SELECT, INSERT, REPLACE<br /><br />The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.<br /><br />May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority. [...]
| `taskAssignment` | SELECT, INSERT, REPLACE<br /><br />Determines how many tasks to use. Possible values include: <ul><li>`max`: Uses as many tasks as possible, up to `maxNumTasks`.</li><li>`auto`: When file sizes can be determined through directory listing (for example: local files, S3, GCS, HDFS) uses as few tasks as possible without exceeding 512 MiB or 10,000 files per task, unless exceeding these limits is necessary to stay within `maxNumTasks`. When calculating the size of files, [...]
| `finalizeAggregations` | SELECT, INSERT, REPLACE<br /><br />Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggr [...]
-| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE<br /><br />Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. [...]
+| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE<br /><br />Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. This is a hint to the MSQ engine; the actual joins in the query may be executed differently than specified. See [Joins](#joins) for more details. [...]
| `rowsInMemory` | INSERT or REPLACE<br /><br />Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. [...]
| `segmentSortOrder` | INSERT or REPLACE<br /><br />Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.<br /><br />You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, [...]
| `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1. [...]
@@ -253,6 +253,12 @@ Joins in multi-stage queries use one of two algorithms based on what you set the
If you omit this context parameter, the MSQ task engine uses broadcast since it's the default join algorithm. The context parameter applies to the entire SQL statement, so you can't mix different
join algorithms in the same query.
+`sqlJoinAlgorithm` is a hint to the planner to execute the join in the specified manner. The planner can decide to ignore
+the hint if it deduces beforehand that the specified algorithm would be detrimental to the performance of the join. This logic
+is currently very limited, and the specified `sqlJoinAlgorithm` is respected in most cases, so you should set it
+appropriately. Review the advantages and drawbacks of the [broadcast](#broadcast) and [sort-merge](#sort-merge) joins to
+decide which algorithm to use.
+
### Broadcast
The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how
@@ -439,7 +445,7 @@ The following table describes error codes you may encounter in the `multiStageQu
| <a name="error_TooManyInputFiles">`TooManyInputFiles`</a> | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).<br /><br />If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.<br /><br />`maxInputFiles`: The maximum number of input files/segments per worker per stage.<br [...]
| <a name="error_TooManyPartitions">`TooManyPartitions`</a> | Exceeded the maximum number of partitions for a stage (25,000 partitions).<br /><br />This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded |
| <a name="error_TooManyClusteredByColumns">`TooManyClusteredByColumns`</a> | Exceeded the maximum number of clustering columns for a stage (1,500 columns).<br /><br />This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded.`stage`: The stage number exceeding the limit<br /><br /> |
-| <a name="error_TooManyRowsWithSameKey">`TooManyRowsWithSameKey`</a> | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.<br /><br />`numBytes`: Number of bytes buffered, which may include other keys.<br /><br />`maxBytes`: Maximum number of bytes buffered. |
+| <a name="error_TooManyRowsWithSameKey">`TooManyRowsWithSameKey`</a> | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when the join is executed with the sort-merge join algorithm. | `key`: The key that had a large number of rows.<br /><br />`numBytes`: Number of bytes buffered, which may include other keys.<br /><br />`maxBytes`: Maximum number of bytes buffered. |
| <a name="error_TooManyColumns">`TooManyColumns`</a> | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded. |
| <a name="error_TooManyWarnings">`TooManyWarnings`</a> | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit. <br /><br />`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. |
| <a name="error_TooManyWorkers">`TooManyWorkers`</a> | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously. <br /><br />`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1, [...]
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
index 95f5eae7bb..477c3e0e19 100644
--- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/DataSourcePlan.java
@@ -30,6 +30,7 @@ import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.UOE;
+import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.input.InputSpec;
import org.apache.druid.msq.input.external.ExternalInputSpec;
import org.apache.druid.msq.input.inline.InlineInputSpec;
@@ -56,6 +57,7 @@ import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.query.spec.QuerySegmentSpec;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.RowSignature;
+import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
@@ -79,6 +81,8 @@ public class DataSourcePlan
*/
private static final Map<String, Object> CONTEXT_MAP_NO_SEGMENT_GRANULARITY = new HashMap<>();
+ private static final Logger log = new Logger(DataSourcePlan.class);
+
static {
CONTEXT_MAP_NO_SEGMENT_GRANULARITY.put(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, null);
}
@@ -144,9 +148,13 @@ public class DataSourcePlan
broadcast
);
} else if (dataSource instanceof JoinDataSource) {
- final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+ final JoinAlgorithm preferredJoinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+ final JoinAlgorithm deducedJoinAlgorithm = deduceJoinAlgorithm(
+ preferredJoinAlgorithm,
+ ((JoinDataSource) dataSource)
+ );
- switch (joinAlgorithm) {
+ switch (deducedJoinAlgorithm) {
case BROADCAST:
return forBroadcastHashJoin(
queryKit,
@@ -171,7 +179,7 @@ public class DataSourcePlan
);
default:
- throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm);
+ throw new UOE("Cannot handle join algorithm [%s]", deducedJoinAlgorithm);
}
} else {
throw new UOE("Cannot handle dataSource [%s]", dataSource);
@@ -198,6 +206,48 @@ public class DataSourcePlan
return Optional.ofNullable(subQueryDefBuilder);
}
+ /**
+ * Contains the logic that deduces the join algorithm to be used. Ideally, this logic would run while planning the
+ * native query; however, the resources and structure needed to do so were not in place when this function was added.
+ * Therefore, it runs while planning the MSQ query.
+ * It takes into account the algorithm specified by "sqlJoinAlgorithm" in the query context and the join condition
+ * that is present in the query.
+ */
+ private static JoinAlgorithm deduceJoinAlgorithm(JoinAlgorithm preferredJoinAlgorithm, JoinDataSource joinDataSource)
+ {
+ JoinAlgorithm deducedJoinAlgorithm;
+ if (JoinAlgorithm.BROADCAST.equals(preferredJoinAlgorithm)) {
+ deducedJoinAlgorithm = JoinAlgorithm.BROADCAST;
+ } else if (isConditionEqualityOnLeftAndRightColumns(joinDataSource.getConditionAnalysis())) {
+ deducedJoinAlgorithm = JoinAlgorithm.SORT_MERGE;
+ } else {
+ deducedJoinAlgorithm = JoinAlgorithm.BROADCAST;
+ }
+
+ if (deducedJoinAlgorithm != preferredJoinAlgorithm) {
+ log.debug(
+ "User wanted to plan join [%s] as [%s], however the join will be executed as [%s]",
+ joinDataSource,
+ preferredJoinAlgorithm.toString(),
+ deducedJoinAlgorithm.toString()
+ );
+ }
+
+ return deducedJoinAlgorithm;
+ }
+
+ /**
+ * Checks whether every equality in the join condition compares a column of the left table with a column of the right
+ * table, that is, whether the condition has the form table1.columnA = table2.columnB && table1.columnC = table2.columnD && ....
+ * The sortMerge algorithm can handle these types of join conditions.
+ */
+ private static boolean isConditionEqualityOnLeftAndRightColumns(JoinConditionAnalysis joinConditionAnalysis)
+ {
+ return joinConditionAnalysis.getEquiConditions()
+ .stream()
+ .allMatch(equality -> equality.getLeftExpr().isIdentifier());
+ }
+
/**
* Whether this datasource must be processed by a single worker. True if, and only if, all inputs are broadcast.
*/
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
index d751946f24..0d4b3aff2f 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java
@@ -41,12 +41,14 @@ import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
import org.apache.druid.msq.indexing.destination.MSQSelectDestination;
import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
+import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
import org.apache.druid.msq.test.CounterSnapshotMatcher;
import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.test.MSQTestFileUtils;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.LookupDataSource;
+import org.apache.druid.query.Query;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
@@ -93,6 +95,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -2013,6 +2016,106 @@ public class MSQSelectTest extends MSQTestBase
.verifyResults();
}
+ @Test
+ public void testJoinUsesDifferentAlgorithm()
+ {
+
+ // This test asserts that the join algorithm used differs from the one supplied. In sqlCompatible() mode
+ // the query is planned differently, so the sortMerge processor does get used. Rather than handling that case
+ // separately, we skip this test in sqlCompatible() mode, since a similar test is described in
+ // CalciteJoinQueryMSQTest and we don't want to repeat it here.
+ if (NullHandling.sqlCompatible()) {
+ return;
+ }
+
+ RowSignature rowSignature = RowSignature.builder().add("cnt", ColumnType.LONG).build();
+
+ Map<String, Object> queryContext = new HashMap<>(context);
+ queryContext.put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.toString());
+
+ Query<?> expectedQuery;
+
+ expectedQuery = GroupByQuery
+ .builder()
+ .setDataSource(
+ join(
+ new QueryDataSource(
+ newScanQueryBuilder()
+ .dataSource("foo")
+ .virtualColumns(expressionVirtualColumn("v0", "0", ColumnType.LONG))
+ .columns("v0")
+ .context(defaultScanQueryContext(
+ queryContext,
+ RowSignature.builder().add("v0", ColumnType.LONG).build()
+ ))
+ .intervals(querySegmentSpec(Intervals.ETERNITY))
+ .build()
+ ),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setDataSource("foo")
+ .setVirtualColumns(expressionVirtualColumn("v0", "1", ColumnType.LONG))
+ .setDimensions(
+ new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+ new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+ )
+ .setContext(queryContext)
+ .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
+ .setGranularity(Granularities.ALL)
+ .build()
+
+ ),
+ "j0.",
+ "(floor(100) == \"j0.d0\")",
+ JoinType.LEFT
+ )
+ )
+ .setAggregatorSpecs(
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ new SelectorDimFilter("j0.d1", null, null),
+ "a0"
+ )
+ )
+ .setContext(queryContext)
+ .setQuerySegmentSpec(querySegmentSpec(Intervals.ETERNITY))
+ .setGranularity(Granularities.ALL)
+ .build();
+
+ testSelectQuery()
+ .setSql(
+ "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) AS cnt "
+ + "FROM foo"
+ )
+ .setExpectedRowSignature(rowSignature)
+ .setExpectedMSQSpec(
+ MSQSpec
+ .builder()
+ .query(expectedQuery)
+ .columnMappings(new ColumnMappings(
+ ImmutableList.of(
+ new ColumnMapping("a0", "cnt")
+ )
+ ))
+ .destination(isDurableStorageDestination()
+ ? DurableStorageMSQDestination.INSTANCE
+ : TaskReportMSQDestination.INSTANCE)
+ .tuningConfig(MSQTuningConfig.defaultConfig())
+ .build())
+ .setQueryContext(queryContext)
+ .addAdhocReportAssertions(
+ msqTaskReportPayload -> msqTaskReportPayload.getStages().getStages().stream().noneMatch(
+ stage -> stage.getStageDefinition()
+ .getProcessorFactory()
+ .getClass()
+ .equals(SortMergeJoinFrameProcessorFactory.class)
+ ),
+ "assert the query didn't use sort merge"
+ )
+ .setExpectedResultRows(ImmutableList.of(new Object[]{6L}))
+ .verifyResults();
+ }
+
@Nonnull
private List<Object[]> expectedMultiValueFooRowsGroup()
{
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index 68965f40bf..4e00fd657a 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -202,6 +202,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
+import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
@@ -807,6 +808,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
protected Set<Interval> expectedTombstoneIntervals = null;
protected List<Object[]> expectedResultRows = null;
protected Matcher<Throwable> expectedValidationErrorMatcher = null;
+ protected List<Pair<Predicate<MSQTaskReportPayload>, String>> adhocReportAssertionAndReasons = new ArrayList<>();
protected Matcher<Throwable> expectedExecutionErrorMatcher = null;
protected MSQFault expectedMSQFault = null;
protected Class<? extends MSQFault> expectedMSQFaultClass = null;
@@ -868,6 +870,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
return asBuilder();
}
+ public Builder addAdhocReportAssertions(Predicate<MSQTaskReportPayload> predicate, String reason)
+ {
+ this.adhocReportAssertionAndReasons.add(Pair.of(predicate, reason));
+ return asBuilder();
+ }
+
public Builder setExpectedValidationErrorMatcher(Matcher<Throwable> expectedValidationErrorMatcher)
{
this.expectedValidationErrorMatcher = expectedValidationErrorMatcher;
@@ -1230,6 +1238,11 @@ public class MSQTestBase extends BaseCalciteQueryTest
}
Assert.assertEquals(expectedTombstoneSegmentIds, tombstoneSegmentIds);
}
+
+ for (Pair<Predicate<MSQTaskReportPayload>, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) {
+ Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(reportPayload));
+ }
+
// assert results
assertResultsEquals(sql, expectedResultRows, transformedOutputRows);
}
@@ -1340,6 +1353,9 @@ public class MSQTestBase extends BaseCalciteQueryTest
log.info("found row signature %s", payload.getResults().getSignature());
log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n")));
+ for (Pair<Predicate<MSQTaskReportPayload>, String> adhocReportAssertionAndReason : adhocReportAssertionAndReasons) {
+ Assert.assertTrue(adhocReportAssertionAndReason.rhs, adhocReportAssertionAndReason.lhs.test(payload));
+ }
log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec));
return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows));
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
index c6ba499921..7bbcc44799 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidJoinRule.java
@@ -39,6 +39,7 @@ import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.java.util.common.Pair;
@@ -92,7 +93,7 @@ public class DruidJoinRule extends RelOptRule
// 1) Can handle the join condition as a native join.
// 2) Left has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
// 3) Right has a PartialDruidQuery (i.e., is a real query, not top-level UNION ALL).
- return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right)
+ return canHandleCondition(join.getCondition(), join.getLeft().getRowType(), right, join.getCluster().getRexBuilder())
&& left.getPartialDruidQuery() != null
&& right.getPartialDruidQuery() != null;
}
@@ -116,7 +117,8 @@ public class DruidJoinRule extends RelOptRule
ConditionAnalysis conditionAnalysis = analyzeCondition(
join.getCondition(),
join.getLeft().getRowType(),
- right
+ right,
+ rexBuilder
).get();
final boolean isLeftDirectAccessPossible = enableLeftScanDirect && (left instanceof DruidQueryRel);
@@ -223,9 +225,9 @@ public class DruidJoinRule extends RelOptRule
* Returns whether {@link #analyzeCondition} would return something.
*/
@VisibleForTesting
- boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right)
+ boolean canHandleCondition(final RexNode condition, final RelDataType leftRowType, DruidRel<?> right, final RexBuilder rexBuilder)
{
- return analyzeCondition(condition, leftRowType, right).isPresent();
+ return analyzeCondition(condition, leftRowType, right, rexBuilder).isPresent();
}
/**
@@ -235,7 +237,8 @@ public class DruidJoinRule extends RelOptRule
private Optional<ConditionAnalysis> analyzeCondition(
final RexNode condition,
final RelDataType leftRowType,
- final DruidRel<?> right
+ final DruidRel<?> right,
+ final RexBuilder rexBuilder
)
{
final List<RexNode> subConditions = decomposeAnd(condition);
@@ -266,8 +269,29 @@ public class DruidJoinRule extends RelOptRule
continue;
}
- if (!subCondition.isA(SqlKind.EQUALS)) {
- // If it's not EQUALS, it's not supported.
+ RexNode firstOperand;
+ RexNode secondOperand;
+
+ if (subCondition.isA(SqlKind.INPUT_REF)) {
+ firstOperand = rexBuilder.makeLiteral(true);
+ secondOperand = subCondition;
+
+ if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) {
+ plannerContext.setPlanningError(
+ "SQL requires a join with '%s' condition where the column is of the type %s, that is not supported",
+ subCondition.getKind(),
+ secondOperand.getType().getSqlTypeName()
+ );
+ return Optional.empty();
+
+ }
+ } else if (subCondition.isA(SqlKind.EQUALS)) {
+ final List<RexNode> operands = ((RexCall) subCondition).getOperands();
+ Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size());
+ firstOperand = operands.get(0);
+ secondOperand = operands.get(1);
+ } else {
+ // If it's not EQUALS or a BOOLEAN input ref, it's not supported.
plannerContext.setPlanningError(
"SQL requires a join with '%s' condition that is not supported.",
subCondition.getKind()
@@ -275,16 +299,13 @@ public class DruidJoinRule extends RelOptRule
return Optional.empty();
}
- final List<RexNode> operands = ((RexCall) subCondition).getOperands();
- Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size());
-
- if (isLeftExpression(operands.get(0), numLeftFields) && isRightInputRef(operands.get(1), numLeftFields)) {
- equalitySubConditions.add(Pair.of(operands.get(0), (RexInputRef) operands.get(1)));
- rightColumns.add((RexInputRef) operands.get(1));
- } else if (isRightInputRef(operands.get(0), numLeftFields)
- && isLeftExpression(operands.get(1), numLeftFields)) {
- equalitySubConditions.add(Pair.of(operands.get(1), (RexInputRef) operands.get(0)));
- rightColumns.add((RexInputRef) operands.get(0));
+ if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) {
+ equalitySubConditions.add(Pair.of(firstOperand, (RexInputRef) secondOperand));
+ rightColumns.add((RexInputRef) secondOperand);
+ } else if (isRightInputRef(firstOperand, numLeftFields)
+ && isLeftExpression(secondOperand, numLeftFields)) {
+ equalitySubConditions.add(Pair.of(secondOperand, (RexInputRef) firstOperand));
+ rightColumns.add((RexInputRef) firstOperand);
} else {
// Cannot handle this condition.
plannerContext.setPlanningError("SQL is resulting in a join that has unsupported operand types.");
@@ -310,7 +331,12 @@ public class DruidJoinRule extends RelOptRule
}
}
- return Optional.of(new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions));
+ return Optional.of(
+ new ConditionAnalysis(
+ numLeftFields,
+ equalitySubConditions,
+ literalSubConditions
+ ));
}
@VisibleForTesting
@@ -341,13 +367,6 @@ public class DruidJoinRule extends RelOptRule
private boolean isLeftExpression(final RexNode rexNode, final int numLeftFields)
{
- if (!plannerContext.getJoinAlgorithm().canHandleLeftExpressions()) {
- // Must be INPUT_REF.
- if (!rexNode.isA(SqlKind.INPUT_REF)) {
- return false;
- }
- }
-
return ImmutableBitSet.range(numLeftFields).contains(RelOptUtil.InputFinder.bits(rexNode));
}
@@ -375,6 +394,7 @@ public class DruidJoinRule extends RelOptRule
*/
private final List<RexLiteral> literalSubConditions;
+
ConditionAnalysis(
int numLeftFields,
List<Pair<RexNode, RexInputRef>> equalitySubConditions,
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
index db631fe674..c4ff4a17a3 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
@@ -5654,4 +5654,128 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
)
);
}
+
+ @Test
+ public void testJoinWithInputRefCondition()
+ {
+ cannotVectorize();
+ Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+
+ Query<?> expectedQuery;
+
+ if (!NullHandling.sqlCompatible()) {
+ expectedQuery = Druids.newTimeseriesQueryBuilder()
+ .dataSource(
+ join(
+ new TableDataSource(CalciteTests.DATASOURCE1),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1))
+ .setVirtualColumns(expressionVirtualColumn(
+ "v0",
+ "1",
+ ColumnType.LONG
+ ))
+ .setDimensions(
+ new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+ new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+ )
+ .build()
+ ),
+ "j0.",
+ "(floor(100) == \"j0.d0\")",
+ JoinType.LEFT
+ )
+ )
+ .granularity(Granularities.ALL)
+ .aggregators(aggregators(
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ new SelectorDimFilter("j0.d1", null, null)
+ )
+ ))
+ .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .context(context)
+ .build();
+
+ } else {
+ expectedQuery = Druids.newTimeseriesQueryBuilder()
+ .dataSource(
+ join(
+ join(
+ new TableDataSource("foo"),
+ new QueryDataSource(
+ Druids.newTimeseriesQueryBuilder()
+ .dataSource("foo")
+ .aggregators(
+ new CountAggregatorFactory("a0"),
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a1"),
+ not(selector("m1", null, null)),
+ "a1"
+ )
+ )
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .context(context)
+ .build()
+ ),
+ "j0.",
+ "1",
+ JoinType.INNER
+ ),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1))
+ .setVirtualColumns(expressionVirtualColumn(
+ "v0",
+ "1",
+ ColumnType.LONG
+ ))
+ .setDimensions(
+ new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+ new DefaultDimensionSpec("v0", "d1", ColumnType.LONG)
+ )
+ .build()
+ ),
+ "_j0.",
+ "(floor(100) == \"_j0.d0\")",
+ JoinType.LEFT
+ )
+ )
+ .granularity(Granularities.ALL)
+ .aggregators(aggregators(
+ new FilteredAggregatorFactory(
+ new CountAggregatorFactory("a0"),
+ or(
+ new SelectorDimFilter("j0.a0", "0", null),
+ and(
+ selector("_j0.d1", null, null),
+ expressionFilter("(\"j0.a1\" >= \"j0.a0\")")
+ )
+
+ )
+ )
+ ))
+ .context(getTimeseriesContextWithFloorTime(TIMESERIES_CONTEXT_BY_GRAN, "d0"))
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .context(context)
+ .build();
+
+ }
+
+ testQuery(
+ "SELECT COUNT(*) FILTER (WHERE FLOOR(100) NOT IN (SELECT m1 FROM foo)) "
+ + "FROM foo",
+ context,
+ ImmutableList.of(expectedQuery),
+ ImmutableList.of(
+ new Object[]{6L}
+ )
+ );
+ }
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
index 41c6895dff..e531580162 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidJoinRuleTest.java
@@ -84,7 +84,8 @@ public class DruidJoinRuleTest
rexBuilder.makeInputRef(joinType, 1)
),
leftType,
- null
+ null,
+ rexBuilder
)
);
}
@@ -104,7 +105,8 @@ public class DruidJoinRuleTest
rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
),
leftType,
- null
+ null,
+ rexBuilder
)
);
}
@@ -124,7 +126,8 @@ public class DruidJoinRuleTest
)
),
leftType,
- null
+ null,
+ rexBuilder
)
);
}
@@ -140,7 +143,8 @@ public class DruidJoinRuleTest
rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0)
),
leftType,
- null
+ null,
+ rexBuilder
)
);
}
@@ -156,7 +160,8 @@ public class DruidJoinRuleTest
rexBuilder.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 1)
),
leftType,
- null
+ null,
+ rexBuilder
)
);
}
@@ -168,7 +173,8 @@ public class DruidJoinRuleTest
druidJoinRule.canHandleCondition(
rexBuilder.makeLiteral(true),
leftType,
- null
+ null,
+ rexBuilder
)
);
}
@@ -180,7 +186,8 @@ public class DruidJoinRuleTest
druidJoinRule.canHandleCondition(
rexBuilder.makeLiteral(false),
leftType,
- null
+ null,
+ rexBuilder
)
);
}