[FLINK-5266] [table] Inject a projection before aggregations to prune unused fields.
This closes #2961. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/15e7f0a8 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/15e7f0a8 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/15e7f0a8 Branch: refs/heads/master Commit: 15e7f0a8c7fd161d5847e7b2afae35b212ea23f0 Parents: 5dab934 Author: Kurt Young <[email protected]> Authored: Thu Dec 8 10:35:43 2016 +0800 Committer: Fabian Hueske <[email protected]> Committed: Thu Dec 15 11:36:40 2016 +0100 ---------------------------------------------------------------------- .../api/table/plan/ProjectionTranslator.scala | 105 ++++-- .../org/apache/flink/api/table/table.scala | 83 ++--- .../org/apache/flink/api/table/windows.scala | 2 +- .../scala/stream/table/GroupWindowTest.scala | 120 +++++-- .../api/table/plan/FieldProjectionTest.scala | 317 +++++++++++++++++++ .../flink/api/table/utils/TableTestBase.scala | 4 + 6 files changed, 551 insertions(+), 80 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala index 22b77b4..a25c402 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/ProjectionTranslator.scala @@ -29,31 +29,19 @@ object ProjectionTranslator { /** * Extracts and deduplicates all aggregation and window property expressions (zero, one, or more) - * from all expressions and replaces the original expressions by field 
accesses expressions. + * from the given expressions. * - * @param exprs a list of expressions to convert + * @param exprs a list of expressions to extract * @param tableEnv the TableEnvironment - * @return a Tuple3, the first field contains the converted expressions, the second field the - * extracted and deduplicated aggregations, and the third field the extracted and - * deduplicated window properties. + * @return a Tuple2, the first field contains the extracted and deduplicated aggregations, + * and the second field contains the extracted and deduplicated window properties. */ def extractAggregationsAndProperties( exprs: Seq[Expression], - tableEnv: TableEnvironment) - : (Seq[NamedExpression], Seq[NamedExpression], Seq[NamedExpression]) = { - - val (aggNames, propNames) = - exprs.foldLeft( (Map[Expression, String](), Map[Expression, String]()) ) { - (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2) - } - - val replaced = exprs - .map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames)) - .map(UnresolvedAlias) - val aggs = aggNames.map( a => Alias(a._1, a._2)).toSeq - val props = propNames.map( p => Alias(p._1, p._2)).toSeq - - (replaced, aggs, props) + tableEnv: TableEnvironment): (Map[Expression, String], Map[Expression, String]) = { + exprs.foldLeft((Map[Expression, String](), Map[Expression, String]())) { + (x, y) => identifyAggregationsAndProperties(y, tableEnv, x._1, x._2) + } } /** Identifies and deduplicates aggregation functions and window properties. */ @@ -106,7 +94,24 @@ object ProjectionTranslator { } } - /** Replaces aggregations and projections by named field references. */ + /** + * Replaces expressions with deduplicated aggregations and properties. 
+ * + * @param exprs a list of expressions to replace + * @param tableEnv the TableEnvironment + * @param aggNames the deduplicated aggregations + * @param propNames the deduplicated properties + * @return a list of replaced expressions + */ + def replaceAggregationsAndProperties( + exprs: Seq[Expression], + tableEnv: TableEnvironment, + aggNames: Map[Expression, String], + propNames: Map[Expression, String]): Seq[NamedExpression] = { + exprs.map(replaceAggregationsAndProperties(_, tableEnv, aggNames, propNames)) + .map(UnresolvedAlias) + } + private def replaceAggregationsAndProperties( exp: Expression, tableEnv: TableEnvironment, @@ -197,4 +202,62 @@ object ProjectionTranslator { } projectList } + + /** + * Extract all field references from the given expressions. + * + * @param exprs a list of expressions to extract + * @return a list of field references extracted from the given expressions + */ + def extractFieldReferences(exprs: Seq[Expression]): Seq[NamedExpression] = { + exprs.foldLeft(Set[NamedExpression]()) { + (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences) + }.toSeq + } + + private def identifyFieldReferences( + expr: Expression, + fieldReferences: Set[NamedExpression]): Set[NamedExpression] = expr match { + + case f: UnresolvedFieldReference => + fieldReferences + UnresolvedAlias(f) + + case b: BinaryExpression => + val l = identifyFieldReferences(b.left, fieldReferences) + identifyFieldReferences(b.right, l) + + // Functions calls + case c @ Call(name, args) => + args.foldLeft(fieldReferences) { + (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences) + } + case sfc @ ScalarFunctionCall(clazz, args) => + args.foldLeft(fieldReferences) { + (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences) + } + + // array constructor + case c @ ArrayConstructor(args) => + args.foldLeft(fieldReferences) { + (fieldReferences, expr) => identifyFieldReferences(expr, fieldReferences) + } + + // ignore fields 
from window property + case w : WindowProperty => + fieldReferences + + // keep this case after all unwanted unary expressions + case u: UnaryExpression => + identifyFieldReferences(u.child, fieldReferences) + + // General expression + case e: Expression => + e.productIterator.foldLeft(fieldReferences) { + (fieldReferences, expr) => expr match { + case e: Expression => identifyFieldReferences(e, fieldReferences) + case _ => fieldReferences + } + } + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala index b74ddb0..94c8e8c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/table.scala @@ -20,10 +20,9 @@ package org.apache.flink.api.table import org.apache.calcite.rel.RelNode import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.join.JoinType -import org.apache.flink.api.table.plan.logical.Minus -import org.apache.flink.api.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall} +import org.apache.flink.api.table.expressions._ import org.apache.flink.api.table.plan.ProjectionTranslator._ -import org.apache.flink.api.table.plan.logical._ +import org.apache.flink.api.table.plan.logical.{Minus, _} import org.apache.flink.api.table.sinks.TableSink import scala.collection.JavaConverters._ @@ -77,21 +76,27 @@ class Table( * }}} */ def select(fields: Expression*): Table = { - val expandedFields = expandProjectList(fields, logicalPlan, tableEnv) - val (projection, aggs, props) = extractAggregationsAndProperties(expandedFields, 
tableEnv) - - if (props.nonEmpty) { + val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, tableEnv) + if (propNames.nonEmpty) { throw ValidationException("Window properties can only be used on windowed tables.") } - if (aggs.nonEmpty) { + if (aggNames.nonEmpty) { + val projectsOnAgg = replaceAggregationsAndProperties( + expandedFields, tableEnv, aggNames, propNames) + val projectFields = extractFieldReferences(expandedFields) + new Table(tableEnv, - Project(projection, - Aggregate(Nil, aggs, logicalPlan).validate(tableEnv)).validate(tableEnv)) + Project(projectsOnAgg, + Aggregate(Nil, aggNames.map(a => Alias(a._1, a._2)).toSeq, + Project(projectFields, logicalPlan).validate(tableEnv) + ).validate(tableEnv) + ).validate(tableEnv) + ) } else { new Table(tableEnv, - Project(projection, logicalPlan).validate(tableEnv)) + Project(expandedFields.map(UnresolvedAlias), logicalPlan).validate(tableEnv)) } } @@ -806,24 +811,21 @@ class GroupedTable( * }}} */ def select(fields: Expression*): Table = { - - val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv) - - if (props.nonEmpty) { + val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv) + if (propNames.nonEmpty) { throw ValidationException("Window properties can only be used on windowed tables.") } - val logical = - Project( - projection, - Aggregate( - groupKey, - aggs, - table.logicalPlan - ).validate(table.tableEnv) - ).validate(table.tableEnv) + val projectsOnAgg = replaceAggregationsAndProperties( + fields, table.tableEnv, aggNames, propNames) + val projectFields = extractFieldReferences(fields ++ groupKey) - new Table(table.tableEnv, logical) + new Table(table.tableEnv, + Project(projectsOnAgg, + Aggregate(groupKey, aggNames.map(a => Alias(a._1, a._2)).toSeq, + Project(projectFields, table.logicalPlan).validate(table.tableEnv) + ).validate(table.tableEnv) + ).validate(table.tableEnv)) } /** @@ -877,24 +879,29 @@ class 
GroupWindowedTable( * }}} */ def select(fields: Expression*): Table = { + val (aggNames, propNames) = extractAggregationsAndProperties(fields, table.tableEnv) + val projectsOnAgg = replaceAggregationsAndProperties( + fields, table.tableEnv, aggNames, propNames) + + val projectFields = (table.tableEnv, window) match { + // event time can be arbitrary field in batch environment + case (_: BatchTableEnvironment, w: EventTimeWindow) => + extractFieldReferences(fields ++ groupKey ++ Seq(w.timeField)) + case (_, _) => + extractFieldReferences(fields ++ groupKey) + } - val (projection, aggs, props) = extractAggregationsAndProperties(fields, table.tableEnv) - - val groupWindow = window.toLogicalWindow - - val logical = + new Table(table.tableEnv, Project( - projection, + projectsOnAgg, WindowAggregate( groupKey, - groupWindow, - props, - aggs, - table.logicalPlan + window.toLogicalWindow, + propNames.map(a => Alias(a._1, a._2)).toSeq, + aggNames.map(a => Alias(a._1, a._2)).toSeq, + Project(projectFields, table.logicalPlan).validate(table.tableEnv) ).validate(table.tableEnv) - ).validate(table.tableEnv) - - new Table(table.tableEnv, logical) + ).validate(table.tableEnv)) } /** http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala index 32d67d7..5637d7a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/windows.scala @@ -48,7 +48,7 @@ trait GroupWindow { * @param timeField defines the time mode for streaming tables. For batch table it defines the * time attribute on which is grouped. 
*/ -abstract class EventTimeWindow(timeField: Expression) extends GroupWindow { +abstract class EventTimeWindow(val timeField: Expression) extends GroupWindow { protected var name: Option[Expression] = None http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala index b59b151..9c2d6b3 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/table/GroupWindowTest.scala @@ -164,7 +164,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", ProcessingTimeTumblingGroupWindow(None, 50.milli)), term("select", "string", "COUNT(int) AS TMP_0") @@ -185,7 +189,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", ProcessingTimeTumblingGroupWindow(None, 2.rows)), term("select", "string", "COUNT(int) AS TMP_0") @@ -206,7 +214,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 
5.milli)), term("select", "string", "COUNT(int) AS TMP_0") @@ -249,7 +261,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", ProcessingTimeSlidingGroupWindow(None, 50.milli, 50.milli)), term("select", "string", "COUNT(int) AS TMP_0") @@ -270,7 +286,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", ProcessingTimeSlidingGroupWindow(None, 2.rows, 1.rows)), term("select", "string", "COUNT(int) AS TMP_0") @@ -291,7 +311,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 8.milli, 10.milli)), term("select", "string", "COUNT(int) AS TMP_0") @@ -334,7 +358,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeSessionGroupWindow(None, RowtimeAttribute(), 7.milli)), term("select", "string", "COUNT(int) AS TMP_0") @@ -355,7 +383,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", ProcessingTimeTumblingGroupWindow(None, 50.milli)), term("select", "string", "COUNT(int) AS TMP_0") @@ -375,7 +407,11 @@ 
class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", ProcessingTimeTumblingGroupWindow(None, 2.rows)), term("select", "COUNT(int) AS TMP_0") ) @@ -394,7 +430,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 5.milli)), term("select", "COUNT(int) AS TMP_0") ) @@ -414,7 +454,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", EventTimeTumblingGroupWindow(None, RowtimeAttribute(), 2.rows)), term("select", "COUNT(int) AS TMP_0") ) @@ -434,7 +478,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", ProcessingTimeSlidingGroupWindow(None, 50.milli, 50.milli)), term("select", "COUNT(int) AS TMP_0") ) @@ -453,7 +501,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", ProcessingTimeSlidingGroupWindow(None, 2.rows, 1.rows)), term("select", "COUNT(int) AS TMP_0") ) @@ -472,7 +524,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 8.milli, 10.milli)), term("select", 
"COUNT(int) AS TMP_0") ) @@ -492,7 +548,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", EventTimeSlidingGroupWindow(None, RowtimeAttribute(), 2.rows, 1.rows)), term("select", "COUNT(int) AS TMP_0") ) @@ -511,7 +571,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "int") + ), term("window", EventTimeSessionGroupWindow(None, RowtimeAttribute(), 7.milli)), term("select", "COUNT(int) AS TMP_0") ) @@ -531,7 +595,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeTumblingGroupWindow( @@ -560,7 +628,11 @@ class GroupWindowTest extends TableTestBase { val expected = unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeSlidingGroupWindow( @@ -592,7 +664,11 @@ class GroupWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeSessionGroupWindow( @@ -626,7 +702,11 @@ class GroupWindowTest extends TableTestBase { "DataStreamCalc", unaryNode( "DataStreamAggregate", - streamTableNode(0), + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "string", "int") + ), term("groupBy", "string"), term("window", EventTimeTumblingGroupWindow( 
http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala new file mode 100644 index 0000000..1cefb8a --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/plan/FieldProjectionTest.scala @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.api.table.plan + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.table._ +import org.apache.flink.api.table.ValidationException +import org.apache.flink.api.table.expressions.{RowtimeAttribute, Upper, WindowReference} +import org.apache.flink.api.table.functions.ScalarFunction +import org.apache.flink.api.table.plan.FieldProjectionTest._ +import org.apache.flink.api.table.plan.logical.EventTimeTumblingGroupWindow +import org.apache.flink.api.table.utils.TableTestBase +import org.apache.flink.api.table.utils.TableTestUtil._ +import org.junit.Test + +/** + * Tests for all the situations when we can do fields projection. Like selecting few fields + * from a large field count source. + */ +class FieldProjectionTest extends TableTestBase { + + val util = batchTestUtil() + + val streamUtil = streamTestUtil() + + @Test + def testSimpleSelect(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.select('a, 'b) + + val expected = unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectAllFields(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable1 = sourceTable.select('*) + val resultTable2 = sourceTable.select('a, 'b, 'c, 'd) + + val expected = batchTableNode(0) + + util.verifyTable(resultTable1, expected) + util.verifyTable(resultTable2, expected) + } + + @Test + def testSelectAggregation(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.select('a.sum, 'b.max) + + val expected = unaryNode( + "DataSetAggregate", + binaryNode( + "DataSetUnion", + values( + "DataSetValues", + tuples(List(null, null)), + term("values", "a", "b") + ), + unaryNode( + "DataSetCalc", + batchTableNode(0), + 
term("select", "a", "b") + ), + term("union", "a", "b") + ), + term("select", "SUM(a) AS TMP_0", "MAX(b) AS TMP_1") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectFunction(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + + util.tEnv.registerFunction("hashCode", MyHashCode) + + val resultTable = sourceTable.select("hashCode(c), b") + + val expected = unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", s"${MyHashCode.getClass.getCanonicalName}(c) AS _c0", "b") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectFromGroupedTable(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.groupBy('a, 'c).select('a) + + val expected = unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "c") + ), + term("groupBy", "a", "c"), + term("select", "a", "c") + ), + term("select", "a") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectAllFieldsFromGroupedTable(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.groupBy('a, 'c).select('a, 'c) + + val expected = unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "c") + ), + term("groupBy", "a", "c"), + term("select", "a", "c") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectAggregationFromGroupedTable(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.groupBy('c).select('a.sum) + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "c") + ), + term("groupBy", "c"), + 
term("select", "c", "SUM(a) AS TMP_0") + ), + term("select", "TMP_0 AS TMP_1") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectFromGroupedTableWithNonTrivialKey(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.groupBy(Upper('c) as 'k).select('a.sum) + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "c", "UPPER(c) AS k") + ), + term("groupBy", "k"), + term("select", "k", "SUM(a) AS TMP_0") + ), + term("select", "TMP_0 AS TMP_1") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectFromGroupedTableWithFunctionKey(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable.groupBy(MyHashCode('c) as 'k).select('a.sum) + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "c", s"${MyHashCode.getClass.getCanonicalName}(c) AS k") + ), + term("groupBy", "k"), + term("select", "k", "SUM(a) AS TMP_0") + ), + term("select", "TMP_0 AS TMP_1") + ) + + util.verifyTable(resultTable, expected) + } + + @Test + def testSelectFromStreamingWindow(): Unit = { + val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable + .window(Tumble over 5.millis on 'rowtime as 'w) + .select(Upper('c).count, 'a.sum) + + val expected = + unaryNode( + "DataStreamAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "c", "a", "UPPER(c) AS $f2") + ), + term("window", + EventTimeTumblingGroupWindow( + Some(WindowReference("w")), + RowtimeAttribute(), + 5.millis)), + term("select", "COUNT($f2) AS TMP_0", "SUM(a) AS TMP_1") + ) + + streamUtil.verifyTable(resultTable, expected) + } + + @Test + 
def testSelectFromStreamingGroupedWindow(): Unit = { + val sourceTable = streamUtil.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + val resultTable = sourceTable + .groupBy('b) + .window(Tumble over 5.millis on 'rowtime as 'w) + .select(Upper('c).count, 'a.sum, 'b) + + val expected = unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "c", "a", "b", "UPPER(c) AS $f3") + ), + term("groupBy", "b"), + term("window", + EventTimeTumblingGroupWindow( + Some(WindowReference("w")), + RowtimeAttribute(), + 5.millis)), + term("select", "b", "COUNT($f3) AS TMP_0", "SUM(a) AS TMP_1") + ), + term("select", "TMP_0 AS TMP_2", "TMP_1 AS TMP_3", "b") + ) + + streamUtil.verifyTable(resultTable, expected) + } + + @Test(expected = classOf[ValidationException]) + def testSelectFromBatchWindow1(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + + // time field is selected + val resultTable = sourceTable + .window(Tumble over 5.millis on 'a as 'w) + .select('a.sum, 'c.count) + + val expected = "TODO" + + util.verifyTable(resultTable, expected) + } + + @Test(expected = classOf[ValidationException]) + def testSelectFromBatchWindow2(): Unit = { + val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd) + + // time field is not selected + val resultTable = sourceTable + .window(Tumble over 5.millis on 'a as 'w) + .select('c.count) + + val expected = "TODO" + + util.verifyTable(resultTable, expected) + } +} + +object FieldProjectionTest { + + object MyHashCode extends ScalarFunction { + def eval(s: String): Int = s.hashCode() + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/15e7f0a8/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala index 4eaba90..b281dfc 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala @@ -91,6 +91,10 @@ object TableTestUtil { |""".stripMargin.stripLineEnd } + def values(node: String, term: String*): String = { + s"$node(${term.mkString(", ")})" + } + def term(term: AnyRef, value: AnyRef*): String = { s"$term=[${value.mkString(", ")}]" }
