Repository: flink Updated Branches: refs/heads/master e13a7f80e -> 1cc1bb41e
[FLINK-6834] [table] Support scalar functions on Over Window This closes #4070. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1cc1bb41 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1cc1bb41 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1cc1bb41 Branch: refs/heads/master Commit: 1cc1bb41e94b585200d9f7179bdeaa0bec0dcc5d Parents: 8ae4f2b Author: Jark Wu <[email protected]> Authored: Sat Jun 3 23:31:00 2017 +0800 Committer: zentol <[email protected]> Committed: Wed Jun 7 23:06:08 2017 +0200 ---------------------------------------------------------------------- .../flink/table/api/scala/expressionDsl.scala | 2 +- .../flink/table/plan/ProjectionTranslator.scala | 81 +++++++++++++++----- .../scala/stream/table/OverWindowITCase.scala | 33 ++++---- .../api/scala/stream/table/OverWindowTest.scala | 44 +++++++++++ .../OverWindowStringExpressionTest.scala | 35 +++++++++ 5 files changed, 160 insertions(+), 35 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala index b87bb6d..7b424b2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala @@ -456,7 +456,7 @@ trait ImplicitExpressionOperations { * .window(Over partitionBy 'c orderBy 'rowtime preceding 2.rows following CURRENT_ROW as 'w) * .select('c, 'a, 'a.count over 'w, 'a.sum over 'w) */ - def over(alias: Expression) = { + def over(alias: Expression): Expression = { expr match { case _: Aggregation => UnresolvedOverCall( expr.asInstanceOf[Aggregation], http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala index 69b437a..b3799d1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala @@ -226,30 +226,69 @@ object ProjectionTranslator { overWindows: Array[OverWindow], tEnv: TableEnvironment): Seq[Expression] = { - def resolveOverWindow(unresolvedCall: UnresolvedOverCall): Expression = { - - val overWindow = overWindows.find(_.alias.equals(unresolvedCall.alias)) - if (overWindow.isDefined) { - OverCall( - unresolvedCall.agg, - overWindow.get.partitionBy, - overWindow.get.orderBy, - overWindow.get.preceding, - overWindow.get.following) - } else { - unresolvedCall - } - } + exprs.map(e => replaceOverCall(e, overWindows, tEnv)) + } - val projectList = new ListBuffer[Expression] - exprs.foreach { - case Alias(u: UnresolvedOverCall, name, _) => - projectList += Alias(resolveOverWindow(u), name) + /** + * Find and replace UnresolvedOverCall with OverCall + * + * @param expr the expression to check + * @return an expression with correct resolved OverCall + */ + private def replaceOverCall( + expr: Expression, + overWindows: Array[OverWindow], + tableEnv: TableEnvironment): Expression = { + + expr match { case u: UnresolvedOverCall => - projectList += resolveOverWindow(u) - case e: Expression => projectList += e + val overWindow = overWindows.find(_.alias.equals(u.alias)) + if (overWindow.isDefined) { + OverCall( + u.agg, + overWindow.get.partitionBy, + overWindow.get.orderBy, + overWindow.get.preceding, + overWindow.get.following) + } else { + u + } + + case u: UnaryExpression => + val c = replaceOverCall(u.child, overWindows, tableEnv) + u.makeCopy(Array(c)) + + case b: BinaryExpression => + val l = replaceOverCall(b.left, overWindows, tableEnv) + val r = replaceOverCall(b.right, overWindows, tableEnv) + b.makeCopy(Array(l, r)) + + // Functions calls + case c @ Call(name, args: Seq[Expression]) => + val newArgs = + args.map( + (exp: Expression) => + replaceOverCall(exp, overWindows, tableEnv)) + c.makeCopy(Array(name, newArgs)) + + // Scala functions + case sfc @ ScalarFunctionCall(clazz, args: Seq[Expression]) => + val newArgs: Seq[Expression] = + args.map( + (exp: Expression) => + replaceOverCall(exp, overWindows, tableEnv)) + sfc.makeCopy(Array(clazz, newArgs)) + + // Array constructor + case c @ ArrayConstructor(args) => + val newArgs = + c.elements + .map((exp: Expression) => replaceOverCall(exp, overWindows, tableEnv)) + c.makeCopy(Array(newArgs)) + + // Other expressions + case e: Expression => e } - projectList } http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala index dc7d5dc..133328e 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowITCase.scala @@ -26,6 +26,7 @@ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvg +import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.JavaFunc0 import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.scala.stream.table.OverWindowITCase.RowTimeSourceFunction import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamingWithStateTestBase} @@ -110,6 +111,7 @@ class OverWindowITCase extends StreamingWithStateTestBase { .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime) val countFun = new CountAggFunction val weightAvgFun = new WeightedAvg + val plusOne = new JavaFunc0 val windowedTable = table .window(Over partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_RANGE following @@ -117,10 +119,15 @@ class OverWindowITCase extends StreamingWithStateTestBase { .select( 'a, 'b, 'c, 'b.sum over 'w, + "SUM:".toExpr + ('b.sum over 'w), countFun('b) over 'w, + (countFun('b) over 'w) + 1, + plusOne(countFun('b) over 'w), + array('b.avg over 'w, 'b.max over 'w), 'b.avg over 'w, 'b.max over 'w, 'b.min over 'w, + ('b.min over 'w).abs(), weightAvgFun('b, 'a) over 'w) val result = windowedTable.toAppendStream[Row] @@ -128,19 +135,19 @@ class OverWindowITCase extends StreamingWithStateTestBase { env.execute() val expected = mutable.MutableList( - "1,1,Hello,6,3,2,3,1,2", - "1,2,Hello,6,3,2,3,1,2", - "1,3,Hello world,6,3,2,3,1,2", - "1,1,Hi,7,4,1,3,1,1", - "2,1,Hello,1,1,1,1,1,1", - "2,2,Hello world,6,3,2,3,1,2", - "2,3,Hello world,6,3,2,3,1,2", - "1,4,Hello world,11,5,2,4,1,2", - "1,5,Hello world,29,8,3,7,1,3", - "1,6,Hello world,29,8,3,7,1,3", - "1,7,Hello world,29,8,3,7,1,3", - "2,4,Hello world,15,5,3,5,1,3", - "2,5,Hello world,15,5,3,5,1,3" + "1,1,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", + "1,2,Hello,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", + "1,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", + "1,1,Hi,7,SUM:7,4,5,5,[1, 3],1,3,1,1,1", + "2,1,Hello,1,SUM:1,1,2,2,[1, 1],1,1,1,1,1", + "2,2,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", + "2,3,Hello world,6,SUM:6,3,4,4,[2, 3],2,3,1,1,2", + "1,4,Hello world,11,SUM:11,5,6,6,[2, 4],2,4,1,1,2", + "1,5,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3", + "1,6,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3", + "1,7,Hello world,29,SUM:29,8,9,9,[3, 7],3,7,1,1,3", + "2,4,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3", + "2,5,Hello world,15,SUM:15,5,6,6,[3, 5],3,5,1,1,3" ) assertEquals(expected.sorted, StreamITCase.testResults.sorted) http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala index 96e5eb5..49a210c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/OverWindowTest.scala @@ -21,6 +21,7 @@ import org.apache.flink.api.scala._ import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithRetract import org.apache.flink.table.api.{Table, ValidationException} import org.apache.flink.table.api.scala._ +import org.apache.flink.table.expressions.utils.Func1 import org.apache.flink.table.utils.TableTestUtil._ import org.apache.flink.table.utils.{StreamTableTestUtil, TableTestBase} import org.junit.Test @@ -106,6 +107,49 @@ class OverWindowTest extends TableTestBase { streamUtil.tEnv.optimize(result.getRelNode, updatesAsRetraction = true) } + + @Test + def testScalarFunctionsOnOverWindow() = { + val weightedAvg = new WeightedAvgWithRetract + val plusOne = Func1 + + val result = table + .window(Over partitionBy 'b orderBy 'proctime preceding UNBOUNDED_ROW as 'w) + .select( + plusOne('a.sum over 'w as 'wsum) as 'd, + ('a.count over 'w).exp(), + (weightedAvg('c, 'a) over 'w) + 1, + "AVG:".toExpr + (weightedAvg('c, 'a) over 'w), + array(weightedAvg('c, 'a) over 'w, 'a.count over 'w)) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamOverAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "b", "c", "proctime") + ), + term("partitionBy", "b"), + term("orderBy", "proctime"), + term("rows", "BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"), + term("select", "a", "b", "c", "proctime", + "SUM(a) AS w0$o0", + "COUNT(a) AS w0$o1", + "WeightedAvgWithRetract(c, a) AS w0$o2") + ), + term("select", + s"${plusOne.functionIdentifier}(w0$$o0) AS d", + "EXP(CAST(w0$o1)) AS _c1", + "+(w0$o2, 1) AS _c2", + "||('AVG:', CAST(w0$o2)) AS _c3", + "ARRAY(w0$o2, w0$o1) AS _c4") + ) + streamUtil.verifyTable(result, expected) + } + @Test def testProcTimeBoundedPartitionedRowsOver() = { val weightedAvg = new WeightedAvgWithRetract http://git-wip-us.apache.org/repos/asf/flink/blob/1cc1bb41/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala index 04016f1..4c95916 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/stringexpr/OverWindowStringExpressionTest.scala @@ -19,8 +19,10 @@ package org.apache.flink.table.api.scala.stream.table.stringexpr import org.apache.flink.api.scala._ +import org.apache.flink.table.api.java.utils.UserDefinedAggFunctions.WeightedAvgWithRetract import org.apache.flink.table.api.java.{Over => JOver} import org.apache.flink.table.api.scala.{Over => SOver, _} +import org.apache.flink.table.expressions.utils.Func1 import org.apache.flink.table.utils.TableTestBase import org.junit.Test @@ -147,5 +149,38 @@ class OverWindowStringExpressionTest extends TableTestBase { verifyTableEquals(resScala, resJava) } + @Test + def testScalarFunctionsOnOverWindow(): Unit = { + val util = streamTestUtil() + val t = util.addTable[(Long, Int, String, Int, Long)]('a, 'b, 'c, 'd, 'e, 'rowtime.rowtime) + + val weightedAvg = new WeightedAvgWithRetract + val plusOne = Func1 + util.addFunction("plusOne", plusOne) + util.addFunction("weightedAvg", weightedAvg) + + val resScala = t + .window(SOver partitionBy 'a orderBy 'rowtime preceding UNBOUNDED_ROW as 'w) + .select( + array('a.sum over 'w, 'a.count over 'w), + plusOne('b.sum over 'w as 'wsum) as 'd, + ('a.count over 'w).exp(), + (weightedAvg('a, 'b) over 'w) + 1, + "AVG:".toExpr + (weightedAvg('a, 'b) over 'w)) + + val resJava = t + .window(JOver.partitionBy("a").orderBy("rowtime").preceding("unbounded_row").as("w")) + .select( + s""" + |ARRAY(SUM(a) OVER w, COUNT(a) OVER w), + |plusOne(SUM(b) OVER w AS wsum) AS d, + |EXP(COUNT(a) OVER w), + |(weightedAvg(a, b) OVER w) + 1, + |'AVG:' + (weightedAvg(a, b) OVER w) + """.stripMargin) + + verifyTableEquals(resScala, resJava) + } + }
