This is an automated email from the ASF dual-hosted git repository.

cwylie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new ae62092  Fix ClassCastException when inputs to a union are joins (#10950)
ae62092 is described below

commit ae620921df2230dfb05a4396012cb72aae428069
Author: Abhishek Agarwal <[email protected]>
AuthorDate: Tue Mar 9 10:50:26 2021 +0530

    Fix ClassCastException when inputs to a union are joins (#10950)
    
    * Fix union queries
    
    * Add tests
---
 .../sql/calcite/rule/DruidUnionDataSourceRule.java |  14 ++-
 .../druid/sql/calcite/rule/DruidUnionRule.java     |   5 +-
 .../apache/druid/sql/calcite/CalciteQueryTest.java | 114 +++++++++++++++++++++
 3 files changed, 130 insertions(+), 3 deletions(-)

diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java
 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java
index 49e77ce..fe9c0d4 100644
--- 
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java
+++ 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionDataSourceRule.java
@@ -69,8 +69,7 @@ public class DruidUnionDataSourceRule extends RelOptRule
     final DruidRel<?> firstDruidRel = call.rel(1);
     final DruidQueryRel secondDruidRel = call.rel(2);
 
-    // Can only do UNION ALL of inputs that have compatible schemas (or schema 
mappings).
-    return unionRel.all && isUnionCompatible(firstDruidRel, secondDruidRel);
+    return isCompatible(unionRel, firstDruidRel, secondDruidRel);
   }
 
   @Override
@@ -111,6 +110,17 @@ public class DruidUnionDataSourceRule extends RelOptRule
     }
   }
 
+  // Can only do UNION ALL of inputs that have compatible schemas (or schema 
mappings) and right side
+  // is a simple table scan
+  public static boolean isCompatible(final Union unionRel, final DruidRel<?> 
first, final DruidRel<?> second)
+  {
+    if (!(second instanceof DruidQueryRel)) {
+      return false;
+    }
+
+    return unionRel.all && isUnionCompatible(first, second);
+  }
+
   private static boolean isUnionCompatible(final DruidRel<?> first, final 
DruidRel<?> second)
   {
     final Optional<List<String>> columnNames = 
getColumnNamesIfTableOrUnion(first);
diff --git 
a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java 
b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java
index e97ed2b..5863415 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidUnionRule.java
@@ -55,7 +55,10 @@ public class DruidUnionRule extends RelOptRule
   public boolean matches(RelOptRuleCall call)
   {
     // Make DruidUnionRule and DruidUnionDataSourceRule mutually exclusive.
-    return !DruidUnionDataSourceRule.instance().matches(call);
+    final Union unionRel = call.rel(0);
+    final DruidRel<?> firstDruidRel = call.rel(1);
+    final DruidRel<?> secondDruidRel = call.rel(2);
+    return !DruidUnionDataSourceRule.isCompatible(unionRel, firstDruidRel, 
secondDruidRel);
   }
 
   @Override
diff --git 
a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java 
b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
index c4279f2..88fd1dc 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java
@@ -4269,6 +4269,120 @@ public class CalciteQueryTest extends 
BaseCalciteQueryTest
   }
 
   @Test
+  public void testUnionAllTwoQueriesLeftQueryIsJoin() throws Exception
+  {
+    cannotVectorize();
+
+    testQuery(
+        "(SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = 
lookyloo.k)  UNION ALL SELECT SUM(cnt) FROM foo",
+        ImmutableList.of(
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(
+                      join(
+                      new TableDataSource(CalciteTests.DATASOURCE1),
+                      new LookupDataSource("lookyloo"),
+                      "j0.",
+                      equalsCondition(DruidExpression.fromColumn("dim1"), 
DruidExpression.fromColumn("j0.k")),
+                      JoinType.INNER
+                  ))
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .aggregators(aggregators(new CountAggregatorFactory("a0")))
+                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                  .build(),
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(CalciteTests.DATASOURCE1)
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .aggregators(aggregators(new LongSumAggregatorFactory("a0", 
"cnt")))
+                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                  .build()
+        ),
+        ImmutableList.of(new Object[]{1L}, new Object[]{6L})
+    );
+  }
+
+  @Test
+  public void testUnionAllTwoQueriesRightQueryIsJoin() throws Exception
+  {
+    cannotVectorize();
+
+    testQuery(
+        "(SELECT SUM(cnt) FROM foo UNION ALL SELECT COUNT(*) FROM foo INNER 
JOIN lookup.lookyloo ON foo.dim1 = lookyloo.k) ",
+        ImmutableList.of(
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(CalciteTests.DATASOURCE1)
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .aggregators(aggregators(new LongSumAggregatorFactory("a0", 
"cnt")))
+                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                  .build(),
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(
+                      join(
+                          new TableDataSource(CalciteTests.DATASOURCE1),
+                          new LookupDataSource("lookyloo"),
+                          "j0.",
+                          equalsCondition(DruidExpression.fromColumn("dim1"), 
DruidExpression.fromColumn("j0.k")),
+                          JoinType.INNER
+                      ))
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .aggregators(aggregators(new CountAggregatorFactory("a0")))
+                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                  .build()
+        ),
+        ImmutableList.of(new Object[]{6L}, new Object[]{1L})
+    );
+  }
+
+  @Test
+  public void testUnionAllTwoQueriesBothQueriesAreJoin() throws Exception
+  {
+    cannotVectorize();
+
+    testQuery(
+        "("
+        + "SELECT COUNT(*) FROM foo LEFT JOIN lookup.lookyloo ON foo.dim1 = 
lookyloo.k "
+        + "                               UNION ALL                            
           "
+        + "SELECT COUNT(*) FROM foo INNER JOIN lookup.lookyloo ON foo.dim1 = 
lookyloo.k"
+        + ") ",
+        ImmutableList.of(
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(
+                      join(
+                          new TableDataSource(CalciteTests.DATASOURCE1),
+                          new LookupDataSource("lookyloo"),
+                          "j0.",
+                          equalsCondition(DruidExpression.fromColumn("dim1"), 
DruidExpression.fromColumn("j0.k")),
+                          JoinType.LEFT
+                      )
+                  )
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .aggregators(aggregators(new CountAggregatorFactory("a0")))
+                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                  .build(),
+            Druids.newTimeseriesQueryBuilder()
+                  .dataSource(
+                      join(
+                          new TableDataSource(CalciteTests.DATASOURCE1),
+                          new LookupDataSource("lookyloo"),
+                          "j0.",
+                          equalsCondition(DruidExpression.fromColumn("dim1"), 
DruidExpression.fromColumn("j0.k")),
+                          JoinType.INNER
+                      ))
+                  .intervals(querySegmentSpec(Filtration.eternity()))
+                  .granularity(Granularities.ALL)
+                  .aggregators(aggregators(new CountAggregatorFactory("a0")))
+                  .context(TIMESERIES_CONTEXT_DEFAULT)
+                  .build()
+        ),
+        ImmutableList.of(new Object[]{6L}, new Object[]{1L})
+    );
+  }
+
+  @Test
   public void testPruneDeadAggregators() throws Exception
   {
     // Test for ProjectAggregatePruneUnusedCallRule.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to