[FLINK-4832] [table] Fix global aggregation of empty tables (COUNT = 0, SUM/AVG = NULL).
- Fix injects a union with a null record before the global aggregation. This closes #2840 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ecfb5b5f Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ecfb5b5f Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ecfb5b5f Branch: refs/heads/master Commit: ecfb5b5f6fd6bf1555c7240d77dd9aca982f4416 Parents: 0bb6847 Author: Anton Mushin <[email protected]> Authored: Mon Nov 21 15:49:41 2016 +0400 Committer: Fabian Hueske <[email protected]> Committed: Tue Nov 29 13:30:51 2016 +0100 ---------------------------------------------------------------------- .../api/table/plan/rules/FlinkRuleSets.scala | 1 + .../rules/dataSet/DataSetAggregateRule.scala | 6 + .../DataSetAggregateWithNullValuesRule.scala | 96 +++++++ .../scala/batch/sql/AggregationsITCase.scala | 39 +++ .../flink/api/table/AggregationTest.scala | 261 +++++++++++++++++++ .../flink/api/table/utils/TableTestBase.scala | 9 +- 6 files changed, 410 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala index 5653083..26c025e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala @@ -98,6 +98,7 @@ object FlinkRuleSets { // translate to Flink DataSet nodes DataSetAggregateRule.INSTANCE, + DataSetAggregateWithNullValuesRule.INSTANCE, 
DataSetCalcRule.INSTANCE, DataSetJoinRule.INSTANCE, DataSetScanRule.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala index 72ed27e..0311c48 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateRule.scala @@ -37,6 +37,12 @@ class DataSetAggregateRule override def matches(call: RelOptRuleCall): Boolean = { val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate] + //for non grouped agg sets should attach null row to source data + //need apply DataSetAggregateWithNullValuesRule + if (agg.getGroupSet.isEmpty) { + return false + } + // check if we have distinct aggregates val distinctAggs = agg.getAggCallList.exists(_.isDistinct) if (distinctAggs) { http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala new file mode 100644 index 0000000..54cb8d1 --- /dev/null +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table.plan.rules.dataSet + +import org.apache.calcite.plan._ +import scala.collection.JavaConversions._ +import com.google.common.collect.ImmutableList +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.logical.{LogicalValues, LogicalUnion, LogicalAggregate} +import org.apache.calcite.rex.RexLiteral +import org.apache.flink.api.table._ +import org.apache.flink.api.table.plan.nodes.dataset.{DataSetAggregate, DataSetConvention} + +/** + * Rule for insert [[Row]] with null records into a [[DataSetAggregate]] + * Rule apply for non grouped aggregate query + */ +class DataSetAggregateWithNullValuesRule + extends ConverterRule( + classOf[LogicalAggregate], + Convention.NONE, + DataSetConvention.INSTANCE, + "DataSetAggregateWithNullValuesRule") +{ + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate] + + //for grouped agg sets shouldn't attach of 
null row + //need apply other rules. e.g. [[DataSetAggregateRule]] + if (!agg.getGroupSet.isEmpty) { + return false + } + + // check if we have distinct aggregates + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets + val groupSets = agg.getGroupSets.size() == 0 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: LogicalAggregate = rel.asInstanceOf[LogicalAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE) + val cluster: RelOptCluster = rel.getCluster + + val fieldTypes = agg.getInput.getRowType.getFieldList.map(_.getType) + val nullLiterals: ImmutableList[ImmutableList[RexLiteral]] = + ImmutableList.of(ImmutableList.copyOf[RexLiteral]( + for (fieldType <- fieldTypes) + yield { + cluster.getRexBuilder. 
+ makeLiteral(null, fieldType, false).asInstanceOf[RexLiteral] + })) + + val logicalValues = LogicalValues.create(cluster, agg.getInput.getRowType, nullLiterals) + val logicalUnion = LogicalUnion.create(List(logicalValues, agg.getInput), true) + + new DataSetAggregate( + cluster, + traitSet, + RelOptRule.convert(logicalUnion, DataSetConvention.INSTANCE), + agg.getNamedAggCalls, + rel.getRowType, + agg.getInput.getRowType, + agg.getGroupSet.toArray + ) + } +} + +object DataSetAggregateWithNullValuesRule { + val INSTANCE: RelOptRule = new DataSetAggregateWithNullValuesRule +} http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala index 2dce751..35bb7dc 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/batch/sql/AggregationsITCase.scala @@ -258,4 +258,43 @@ class AggregationsITCase( // must fail. 
grouping sets are not supported tEnv.sql(sqlQuery).toDataSet[Row] } + + @Test + def testAggregateEmptyDataSets(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val sqlQuery = "SELECT avg(a), sum(a), count(b) " + + "FROM MyTable where a = 4 group by a" + + val sqlQuery2 = "SELECT avg(a), sum(a), count(b) " + + "FROM MyTable where a = 4" + + val sqlQuery3 = "SELECT avg(a), sum(a), count(b) " + + "FROM MyTable" + + val ds = env.fromElements( + (1: Byte, 1: Short), + (2: Byte, 2: Short)) + .toTable(tEnv, 'a, 'b) + + tEnv.registerTable("MyTable", ds) + + val result = tEnv.sql(sqlQuery) + val result2 = tEnv.sql(sqlQuery2) + val result3 = tEnv.sql(sqlQuery3) + + val results = result.toDataSet[Row].collect() + val expected = Seq.empty + val results2 = result2.toDataSet[Row].collect() + val expected2 = "null,null,0" + val results3 = result3.toDataSet[Row].collect() + val expected3 = "1,3,2" + + assert(results.equals(expected), + "Empty result is expected for grouped set, but actual: " + results) + TestBaseUtils.compareResultAsText(results2.asJava, expected2) + TestBaseUtils.compareResultAsText(results3.asJava, expected3) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala new file mode 100644 index 0000000..6c9d2e8 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/AggregationTest.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.table + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.utils.TableTestBase +import org.apache.flink.api.table.utils.TableTestUtil._ +import org.junit.Test + +/** + * Test for testing aggregate plans. 
+ */ +class AggregationTest extends TableTestBase { + + @Test + def testAggregateQueryBatchSQL(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + + val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable" + + val setValues = unaryNode( + "DataSetValues", + batchTableNode(0), + tuples(List(null,null,null)), + term("values","a","b","c") + ) + val union = unaryNode( + "DataSetUnion", + setValues, + term("union","a","b","c") + ) + + val aggregate = unaryNode( + "DataSetAggregate", + union, + term("select", + "AVG(a) AS EXPR$0", + "SUM(b) AS EXPR$1", + "COUNT(c) AS EXPR$2") + ) + util.verifySql(sqlQuery, aggregate) + } + + @Test + def testAggregateWithFilterQueryBatchSQL(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + + val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1" + + val calcNode = unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b", "c"), + term("where", "=(a, 1)") + ) + + val setValues = unaryNode( + "DataSetValues", + calcNode, + tuples(List(null,null,null)), + term("values","a","b","c") + ) + + val union = unaryNode( + "DataSetUnion", + setValues, + term("union","a","b","c") + ) + + val aggregate = unaryNode( + "DataSetAggregate", + union, + term("select", + "AVG(a) AS EXPR$0", + "SUM(b) AS EXPR$1", + "COUNT(c) AS EXPR$2") + ) + util.verifySql(sqlQuery, aggregate) + } + + @Test + def testAggregateGroupQueryBatchSQL(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + + val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable GROUP BY a" + + val aggregate = unaryNode( + "DataSetAggregate", + batchTableNode(0), + term("groupBy", "a"), + term("select", + "a", + "AVG(a) AS EXPR$0", + "SUM(b) AS EXPR$1", + "COUNT(c) AS EXPR$2") + ) + val expected = unaryNode( + "DataSetCalc", + aggregate, + term("select", + "EXPR$0", + "EXPR$1", + "EXPR$2") + ) + 
util.verifySql(sqlQuery, expected) + } + + @Test + def testAggregateGroupWithFilterQueryBatchSQL(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + + val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable WHERE a = 1 GROUP BY a" + + val calcNode = unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select","a", "b", "c") , + term("where","=(a, 1)") + ) + + val aggregate = unaryNode( + "DataSetAggregate", + calcNode, + term("groupBy", "a"), + term("select", + "a", + "AVG(a) AS EXPR$0", + "SUM(b) AS EXPR$1", + "COUNT(c) AS EXPR$2") + ) + val expected = unaryNode( + "DataSetCalc", + aggregate, + term("select", + "EXPR$0", + "EXPR$1", + "EXPR$2") + ) + util.verifySql(sqlQuery, expected) + } + + @Test + def testAggregateGroupWithFilterTableApi(): Unit = { + + val util = batchTestUtil() + val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + + val resultTable = sourceTable.groupBy('a) + .select('a, 'a.avg, 'b.sum, 'c.count) + .where('a === 1) + + val calcNode = unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b", "c"), + term("where", "=(a, 1)") + ) + + val expected = unaryNode( + "DataSetAggregate", + calcNode, + term("groupBy", "a"), + term("select", + "a", + "AVG(a) AS TMP_0", + "SUM(b) AS TMP_1", + "COUNT(c) AS TMP_2") + ) + + util.verifyTable(resultTable,expected) + } + + @Test + def testAggregateTableApi(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + val resultTable = sourceTable.select('a.avg,'b.sum,'c.count) + + val setValues = unaryNode( + "DataSetValues", + batchTableNode(0), + tuples(List(null,null,null)), + term("values","a","b","c") + ) + val union = unaryNode( + "DataSetUnion", + setValues, + term("union","a","b","c") + ) + + val expected = unaryNode( + "DataSetAggregate", + union, + term("select", + "AVG(a) AS TMP_0", + "SUM(b) AS TMP_1", + "COUNT(c) AS TMP_2") + ) + 
util.verifyTable(resultTable, expected) + } + + @Test + def testAggregateWithFilterTableApi(): Unit = { + val util = batchTestUtil() + val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) + + val resultTable = sourceTable.select('a,'b,'c).where('a === 1) + .select('a.avg,'b.sum,'c.count) + + val calcNode = unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b", "c"), + term("where", "=(a, 1)") + ) + + val setValues = unaryNode( + "DataSetValues", + calcNode, + tuples(List(null,null,null)), + term("values","a","b","c") + ) + + val union = unaryNode( + "DataSetUnion", + setValues, + term("union","a","b","c") + ) + + val expected = unaryNode( + "DataSetAggregate", + union, + term("select", + "AVG(a) AS TMP_0", + "SUM(b) AS TMP_1", + "COUNT(c) AS TMP_2") + ) + + util.verifyTable(resultTable, expected) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ecfb5b5f/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala index 2ea15a0..539bb61 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala @@ -70,20 +70,25 @@ object TableTestUtil { def unaryNode(node: String, input: String, term: String*): String = { s"""$node(${term.mkString(", ")}) |$input - |""".stripMargin + |""".stripMargin.stripLineEnd } def binaryNode(node: String, left: String, right: String, term: String*): String = { s"""$node(${term.mkString(", ")}) |$left |$right - |""".stripMargin + |""".stripMargin.stripLineEnd } def term(term: AnyRef, value: AnyRef*): String = { s"$term=[${value.mkString(", 
")}]" } + def tuples(value:List[AnyRef]*): String={ + val listValues = value.map( listValue => s"{ ${listValue.mkString(", ")} }") + term("tuples","[" + listValues.mkString(", ") + "]") + } + def batchTableNode(idx: Int): String = { s"DataSetScan(table=[[_DataSetTable_$idx]])" }
