This is an automated email from the ASF dual-hosted git repository.

jhyde pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/master by this push:
     new bda7fb1  [CALCITE-4321] JDBC adapter omits FILTER (WHERE ...) expressions when generating SQL (Jeremiah Rhoads Hall)
bda7fb1 is described below

commit bda7fb188b01c0334a7bf1662101e259321893a4
Author: Jeremiah Rhoads Hall <[email protected]>
AuthorDate: Mon Oct 5 13:06:39 2020 -0700

    [CALCITE-4321] JDBC adapter omits FILTER (WHERE ...) expressions when generating SQL (Jeremiah Rhoads Hall)
    
    Working fix and simple test case for FILTER-WHERE conditions
    not being added to aggregate calls during SQL implementation
    in class SqlImplementor.
    
    Add `supportsAggregateFunctionFilter()` to SqlDialect. If a
    dialect does not support FILTER-WHERE, SqlImplementor instead
    generates an aggregate function with CASE. (This works for
    all aggregate functions where NULL values are ignored.)
    
    Extend the CASE rewrite to support COUNT with zero, one or
    more arguments (Julian Hyde).
    
    Close apache/calcite#2204
---
 .../org/apache/calcite/adapter/jdbc/JdbcRules.java |  4 +
 .../apache/calcite/rel/rel2sql/SqlImplementor.java | 92 ++++++++++++++++------
 .../java/org/apache/calcite/sql/SqlDialect.java    |  6 ++
 .../calcite/sql/dialect/BigQuerySqlDialect.java    |  4 +
 .../calcite/sql/dialect/HsqldbSqlDialect.java      |  4 +
 .../calcite/rel/rel2sql/RelToSqlConverterTest.java | 62 ++++++++++++++-
 core/src/test/resources/sql/agg.iq                 | 17 ++++
 7 files changed, 162 insertions(+), 27 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java b/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
index 23fcd43..0f22db3 100644
--- a/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
+++ b/core/src/main/java/org/apache/calcite/adapter/jdbc/JdbcRules.java
@@ -704,6 +704,10 @@ public class JdbcRules {
           throw new InvalidRelException("cannot implement aggregate function "
               + aggCall.getAggregation());
         }
+        if (aggCall.hasFilter() && !dialect.supportsAggregateFunctionFilter()) {
+          throw new InvalidRelException("dialect does not support aggregate "
+              + "functions FILTER clauses");
+        }
       }
     }
 
diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
index 4dce253..61fa55f 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
@@ -18,6 +18,7 @@ package org.apache.calcite.rel.rel2sql;
 
 import org.apache.calcite.linq4j.Ord;
 import org.apache.calcite.linq4j.tree.Expressions;
+import org.apache.calcite.rel.RelCollation;
 import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.SingleRel;
@@ -1140,44 +1141,83 @@ public abstract class SqlImplementor {
 
     /** Converts a call to an aggregate function to an expression. */
     public SqlNode toSql(AggregateCall aggCall) {
-      final SqlOperator op = aggCall.getAggregation();
-      final List<SqlNode> operandList = Expressions.list();
-      for (int arg : aggCall.getArgList()) {
-        operandList.add(field(arg));
-      }
-
-      if ((op instanceof SqlCountAggFunction) && operandList.isEmpty()) {
-        // If there is no parameter in "count" function, add a star identifier to it
-        operandList.add(SqlIdentifier.STAR);
-      }
+      return toSql(aggCall.getAggregation(), aggCall.isDistinct(),
+          Util.transform(aggCall.getArgList(), this::field),
+          aggCall.filterArg, aggCall.collation);
+    }
 
+    /** Converts a call to an aggregate function, with a given list of operands,
+     * to an expression. */
+    private SqlCall toSql(SqlOperator op, boolean distinct,
+        List<SqlNode> operandList, int filterArg, RelCollation collation) {
       final SqlLiteral qualifier =
-          aggCall.isDistinct() ? SqlSelectKeyword.DISTINCT.symbol(POS) : null;
-      final SqlNode[] operands = operandList.toArray(new SqlNode[0]);
-      List<SqlNode> orderByList = Expressions.list();
-      for (RelFieldCollation field : aggCall.collation.getFieldCollations()) {
-        addOrderItem(orderByList, field);
-      }
-      SqlNodeList orderList = new SqlNodeList(orderByList, POS);
+          distinct ? SqlSelectKeyword.DISTINCT.symbol(POS) : null;
       if (op instanceof SqlSumEmptyIsZeroAggFunction) {
-        final SqlNode node =
-            withOrder(
-                SqlStdOperatorTable.SUM.createCall(qualifier, POS, operands),
-                orderList);
+        final SqlNode node = toSql(SqlStdOperatorTable.SUM, distinct,
+            operandList, filterArg, collation);
         return SqlStdOperatorTable.COALESCE.createCall(POS, node,
             SqlLiteral.createExactNumeric("0", POS));
+      }
+
+      // Handle filter on dialects that do not support FILTER by generating CASE.
+      if (filterArg >= 0 && !dialect.supportsAggregateFunctionFilter()) {
+        // SUM(x) FILTER(WHERE b)  ==>  SUM(CASE WHEN b THEN x END)
+        // COUNT(*) FILTER(WHERE b)  ==>  COUNT(CASE WHEN b THEN 1 END)
+        // COUNT(x) FILTER(WHERE b)  ==>  COUNT(CASE WHEN b THEN x END)
+        // COUNT(x, y) FILTER(WHERE b)  ==>  COUNT(CASE WHEN b THEN x END, y)
+        final SqlNodeList whenList = SqlNodeList.of(field(filterArg));
+        final SqlNodeList thenList =
+            SqlNodeList.of(operandList.isEmpty()
+                ? SqlLiteral.createExactNumeric("1", POS)
+                : operandList.get(0));
+        final SqlNode elseList = SqlLiteral.createNull(POS);
+        final SqlCall caseCall =
+            SqlStdOperatorTable.CASE.createCall(null, POS, null, whenList,
+                thenList, elseList);
+        final List<SqlNode> newOperandList = new ArrayList<>();
+        newOperandList.add(caseCall);
+        if (operandList.size() > 1) {
+          newOperandList.addAll(Util.skip(operandList));
+        }
+        return toSql(op, distinct, newOperandList, -1, collation);
+      }
+
+      final SqlNode[] operands;
+      if (op instanceof SqlCountAggFunction && operandList.isEmpty()) {
+        // If there is no parameter in "count" function, add a star identifier to it
+        operands = new SqlNode[] {SqlIdentifier.STAR};
       } else {
-        return withOrder(op.createCall(qualifier, POS, operands), orderList);
+        operands = operandList.toArray(new SqlNode[0]);
       }
+      final SqlCall call =
+          op.createCall(qualifier, POS, operands);
+
+      // Handle filter by generating FILTER (WHERE ...)
+      final SqlCall call2;
+      if (filterArg < 0) {
+        call2 = call;
+      } else {
+        assert dialect.supportsAggregateFunctionFilter(); // we checked above
+        call2 = SqlStdOperatorTable.FILTER.createCall(POS, call,
+            field(filterArg));
+      }
+
+      // Handle collation
+      return withOrder(call2, collation);
     }
 
     /** Wraps a call in a {@link SqlKind#WITHIN_GROUP} call, if
-     * {@code orderList} is non-empty. */
-    private SqlNode withOrder(SqlCall call, SqlNodeList orderList) {
-      if (orderList == null || orderList.size() == 0) {
+     * {@code collation} is non-empty. */
+    private SqlCall withOrder(SqlCall call, RelCollation collation) {
+      if (collation.getFieldCollations().isEmpty()) {
         return call;
       }
-      return SqlStdOperatorTable.WITHIN_GROUP.createCall(POS, call, orderList);
+      final List<SqlNode> orderByList = new ArrayList<>();
+      for (RelFieldCollation field : collation.getFieldCollations()) {
+        addOrderItem(orderByList, field);
+      }
+      return SqlStdOperatorTable.WITHIN_GROUP.createCall(POS, call,
+          new SqlNodeList(orderByList, POS));
     }
 
     /** Converts a collation to an ORDER BY item. */
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java
index 0c8229e..f3ed086 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlDialect.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlDialect.java
@@ -714,6 +714,12 @@ public class SqlDialect {
     return false;
   }
 
+  /** Returns whether this dialect supports the use of FILTER clauses for
+   * aggregate functions. e.g. {@code COUNT(*) FILTER (WHERE a = 2)}. */
+  public boolean supportsAggregateFunctionFilter() {
+    return true;
+  }
+
   /** Returns whether this dialect supports window functions (OVER clause). */
   public boolean supportsWindowFunctions() {
     return true;
diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/BigQuerySqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/BigQuerySqlDialect.java
index cf125cb..33eb123 100644
--- a/core/src/main/java/org/apache/calcite/sql/dialect/BigQuerySqlDialect.java
+++ b/core/src/main/java/org/apache/calcite/sql/dialect/BigQuerySqlDialect.java
@@ -117,6 +117,10 @@ public class BigQuerySqlDialect extends SqlDialect {
     return false;
   }
 
+  @Override public boolean supportsAggregateFunctionFilter() {
+    return false;
+  }
+
   @Override public @Nonnull SqlParser.Config configureParser(
       SqlParser.Config configBuilder) {
     return super.configureParser(configBuilder)
diff --git a/core/src/main/java/org/apache/calcite/sql/dialect/HsqldbSqlDialect.java b/core/src/main/java/org/apache/calcite/sql/dialect/HsqldbSqlDialect.java
index 6b105fc..bda481c 100644
--- a/core/src/main/java/org/apache/calcite/sql/dialect/HsqldbSqlDialect.java
+++ b/core/src/main/java/org/apache/calcite/sql/dialect/HsqldbSqlDialect.java
@@ -47,6 +47,10 @@ public class HsqldbSqlDialect extends SqlDialect {
     return false;
   }
 
+  @Override public boolean supportsAggregateFunctionFilter() {
+    return false;
+  }
+
   @Override public boolean supportsWindowFunctions() {
     return false;
   }
diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
index dc75751..201af86 100644
--- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
+++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java
@@ -203,6 +203,66 @@ class RelToSqlConverterTest {
     sql(query).ok("SELECT *\nFROM \"foodmart\".\"product\"");
   }
 
+  @Test void testAggregateFilterWhereToSqlFromProductTable() {
+    String query = "select\n"
+        + "  sum(\"shelf_width\") filter (where \"net_weight\" > 0),\n"
+        + "  sum(\"shelf_width\")\n"
+        + "from \"foodmart\".\"product\"\n"
+        + "where \"product_id\" > 0\n"
+        + "group by \"product_id\"";
+    final String expected = "SELECT"
+        + " SUM(\"shelf_width\") FILTER (WHERE \"net_weight\" > 0 IS TRUE),"
+        + " SUM(\"shelf_width\")\n"
+        + "FROM \"foodmart\".\"product\"\n"
+        + "WHERE \"product_id\" > 0\n"
+        + "GROUP BY \"product_id\"";
+    sql(query).ok(expected);
+  }
+
+  @Test void testAggregateFilterWhereToBigQuerySqlFromProductTable() {
+    String query = "select\n"
+        + "  sum(\"shelf_width\") filter (where \"net_weight\" > 0),\n"
+        + "  sum(\"shelf_width\")\n"
+        + "from \"foodmart\".\"product\"\n"
+        + "where \"product_id\" > 0\n"
+        + "group by \"product_id\"";
+    final String expected = "SELECT SUM(CASE WHEN net_weight > 0 IS TRUE"
+        + " THEN shelf_width ELSE NULL END), "
+        + "SUM(shelf_width)\n"
+        + "FROM foodmart.product\n"
+        + "WHERE product_id > 0\n"
+        + "GROUP BY product_id";
+    sql(query).withBigQuery().ok(expected);
+  }
+
+  @Test void testPivotToSqlFromProductTable() {
+    String query = "select * from (\n"
+        + "  select \"shelf_width\", \"net_weight\", \"product_id\"\n"
+        + "  from \"foodmart\".\"product\")\n"
+        + "  pivot (sum(\"shelf_width\") as w, count(*) as c\n"
+        + "    for (\"product_id\") in (10, 20))";
+    final String expected = "SELECT \"net_weight\","
+        + " SUM(\"shelf_width\") FILTER (WHERE \"product_id\" = 10) AS \"10_W\","
+        + " COUNT(*) FILTER (WHERE \"product_id\" = 10) AS \"10_C\","
+        + " SUM(\"shelf_width\") FILTER (WHERE \"product_id\" = 20) AS \"20_W\","
+        + " COUNT(*) FILTER (WHERE \"product_id\" = 20) AS \"20_C\"\n"
+        + "FROM \"foodmart\".\"product\"\n"
+        + "GROUP BY \"net_weight\"";
+    // BigQuery does not support FILTER, so we generate CASE around the
+    // arguments to the aggregate functions.
+    final String expectedBigQuery = "SELECT net_weight,"
+        + " SUM(CASE WHEN product_id = 10 "
+        + "THEN shelf_width ELSE NULL END) AS `10_W`,"
+        + " COUNT(CASE WHEN product_id = 10 THEN 1 ELSE NULL END) AS `10_C`,"
+        + " SUM(CASE WHEN product_id = 20 "
+        + "THEN shelf_width ELSE NULL END) AS `20_W`,"
+        + " COUNT(CASE WHEN product_id = 20 THEN 1 ELSE NULL END) AS `20_C`\n"
+        + "FROM foodmart.product\n"
+        + "GROUP BY net_weight";
+    sql(query).ok(expected)
+        .withBigQuery().ok(expectedBigQuery);
+  }
+
   @Test void testSimpleSelectQueryFromProductTable() {
     String query = "select \"product_id\", \"product_class_id\" from \"product\"";
     final String expected = "SELECT \"product_id\", \"product_class_id\"\n"
@@ -4770,11 +4830,11 @@ class RelToSqlConverterTest {
   }
 
   @Test void testWithinGroup4() {
-    // filter in AggregateCall is not unparsed
     final String query = "select \"product_class_id\", collect(\"net_weight\") "
         + "within group (order by \"net_weight\" desc) filter (where \"net_weight\" > 0)"
         + "from \"product\" group by \"product_class_id\"";
     final String expected = "SELECT \"product_class_id\", COLLECT(\"net_weight\") "
+        + "FILTER (WHERE \"net_weight\" > 0 IS TRUE) "
         + "WITHIN GROUP (ORDER BY \"net_weight\" DESC)\n"
         + "FROM \"foodmart\".\"product\"\n"
         + "GROUP BY \"product_class_id\"";
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 16cb93f..3653b3a 100644
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -1380,6 +1380,23 @@ from (values 0, null, 0, 1) as t(x);
 
 !ok
 
+# Composite COUNT and FILTER
+select count(*) as c,
+  count(*) filter (where z > 1) as cf,
+  count(x) as cx,
+  count(x) filter (where z > 1) as cxf,
+  count(x, y) as cxy,
+  count(x, y) filter (where z > 1) as cxyf
+from (values (1, 1, 1), (2, 2, 2), (3, null, 3), (null, 4, 4)) as t(x, y, z);
++---+----+----+-----+-----+------+
+| C | CF | CX | CXF | CXY | CXYF |
++---+----+----+-----+-----+------+
+| 4 |  3 |  3 |   2 |   2 |    1 |
++---+----+----+-----+-----+------+
+(1 row)
+
+!ok
+
 # [CALCITE-1293] Bad code generated when argument to COUNT(DISTINCT) is a
 # GROUP BY column
 select count(distinct deptno) as cd, count(*) as c

Reply via email to