This is an automated email from the ASF dual-hosted git repository.
jincheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 98184bd  [FLINK-12401][table] Support incremental emit under AccRetract mode for non-window streaming FlatAggregate on Table API. This closes #8550
98184bd is described below
commit 98184bd078d7b957f4cf99d26a5aacef1583fe3a
Author: hequn8128 <[email protected]>
AuthorDate: Wed Jun 12 10:32:34 2019 +0800
[FLINK-12401][table] Support incremental emit under AccRetract mode for non-window streaming FlatAggregate on Table API.
This closes #8550
---
docs/dev/table/tableApi.md | 152 +++++++++++++--------
.../table/functions/TableAggregateFunction.java | 38 +++++-
.../table/codegen/AggregationCodeGenerator.scala | 108 ++++++++++++---
.../org/apache/flink/table/codegen/generated.scala | 14 +-
.../table/runtime/aggregate/AggregateUtil.scala | 13 +-
.../aggregate/GroupTableAggProcessFunction.scala | 14 +-
.../harness/TableAggregateHarnessTest.scala | 55 +++++++-
.../stream/table/TableAggregateITCase.scala | 73 ++++++++--
.../table/utils/UserDefinedTableAggFunctions.scala | 52 ++++++-
9 files changed, 414 insertions(+), 105 deletions(-)
diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index 2143bf7..3995b97 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -2643,46 +2643,63 @@ Table table = input
</td>
<td>
<p>Similar to a <b>GroupBy Aggregation</b>. Groups the rows on the grouping keys and applies the following running table aggregation operator to aggregate rows group-wise. The difference from an AggregateFunction is that a TableAggregateFunction may return 0 or more records for a group. You have to close the "flatAggregate" with a select statement, and the select statement does not support aggregate functions.</p>
+        <p>Instead of using <code>emitValue</code> to output results, you can also use the <code>emitUpdateWithRetract</code> method. Different from <code>emitValue</code>, <code>emitUpdateWithRetract</code> is used to emit values that have been updated. This method outputs data incrementally in retract mode, i.e., once there is an update, we have to retract old records before sending new updated ones. The <code>emitUpdateWithRetract</code> method will be used in preference to the <code>emitValue</code> method if both methods are defined in the table aggregate function, because the method is treated to be more efficient than <code>emitValue</code> as it can output values incrementally.</p>
{% highlight java %}
- public class MyMinMaxAcc {
- public int min = 0;
- public int max = 0;
+/**
+ * Accumulator for Top2.
+ */
+public class Top2Accum {
+ public Integer first;
+ public Integer second;
+}
+
+/**
+ * The top2 user-defined table aggregate function.
+ */
+public class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accum> {
+
+ @Override
+ public Top2Accum createAccumulator() {
+ Top2Accum acc = new Top2Accum();
+ acc.first = Integer.MIN_VALUE;
+ acc.second = Integer.MIN_VALUE;
+ return acc;
}
- public class MyMinMax extends TableAggregateFunction<Row, MyMinMaxAcc> {
- public void accumulate(MyMinMaxAcc acc, int value) {
- if (value < acc.min) {
- acc.min = value;
- }
- if (value > acc.max) {
- acc.max = value;
- }
+ public void accumulate(Top2Accum acc, Integer v) {
+ if (v > acc.first) {
+ acc.second = acc.first;
+ acc.first = v;
+ } else if (v > acc.second) {
+ acc.second = v;
}
+ }
- @Override
- public MyMinMaxAcc createAccumulator() {
- return new MyMinMaxAcc();
+ public void merge(Top2Accum acc, java.lang.Iterable<Top2Accum> iterable) {
+ for (Top2Accum otherAcc : iterable) {
+ accumulate(acc, otherAcc.first);
+ accumulate(acc, otherAcc.second);
}
+ }
- public void emitValue(MyMinMaxAcc acc, Collector<Row> out) {
- out.collect(Row.of(acc.min, acc.min));
- out.collect(Row.of(acc.max, acc.max));
+ public void emitValue(Top2Accum acc, Collector<Tuple2<Integer, Integer>> out) {
+ // emit the value and rank
+ if (acc.first != Integer.MIN_VALUE) {
+ out.collect(Tuple2.of(acc.first, 1));
}
-
- @Override
- public TypeInformation<Row> getResultType() {
- return new RowTypeInfo(Types.INT, Types.INT);
+ if (acc.second != Integer.MIN_VALUE) {
+ out.collect(Tuple2.of(acc.second, 2));
}
}
+}
-TableAggregateFunction tableAggFunc = new MyMinMax();
-tableEnv.registerFunction("tableAggFunc", tableAggFunc);
+tableEnv.registerFunction("top2", new Top2());
Table orders = tableEnv.scan("Orders");
Table result = orders
- .groupBy("a")
- .flatAggregate("tableAggFunc(b) as (x, y)")
- .select("a, x, y");
+ .groupBy("key")
+ .flatAggregate("top2(a) as (v, rank)")
+ .select("key, v, rank");
{% endhighlight %}
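For illustration, the same Top2 function could define <code>emitUpdateWithRetract</code> instead. The sketch below is not part of this patch: the <code>Top2WithRetract</code> name and the <code>old*</code> bookkeeping fields are illustrative. It retracts a rank's previous value only when that rank actually changes:
{% highlight java %}
/**
 * Accumulator that additionally remembers the values emitted last time.
 */
public class Top2WithRetractAccum {
    public Integer first;
    public Integer second;
    public Integer oldFirst;
    public Integer oldSecond;
}

public class Top2WithRetract
        extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2WithRetractAccum> {

    @Override
    public Top2WithRetractAccum createAccumulator() {
        Top2WithRetractAccum acc = new Top2WithRetractAccum();
        acc.first = Integer.MIN_VALUE;
        acc.second = Integer.MIN_VALUE;
        acc.oldFirst = Integer.MIN_VALUE;
        acc.oldSecond = Integer.MIN_VALUE;
        return acc;
    }

    public void accumulate(Top2WithRetractAccum acc, Integer v) {
        if (v > acc.first) {
            acc.second = acc.first;
            acc.first = v;
        } else if (v > acc.second) {
            acc.second = v;
        }
    }

    public void emitUpdateWithRetract(
            Top2WithRetractAccum acc,
            RetractableCollector<Tuple2<Integer, Integer>> out) {
        if (!acc.first.equals(acc.oldFirst)) {
            // retract the old top-1 record before emitting the new one
            if (acc.oldFirst != Integer.MIN_VALUE) {
                out.retract(Tuple2.of(acc.oldFirst, 1));
            }
            out.collect(Tuple2.of(acc.first, 1));
            acc.oldFirst = acc.first;
        }
        if (!acc.second.equals(acc.oldSecond)) {
            // retract the old top-2 record before emitting the new one
            if (acc.oldSecond != Integer.MIN_VALUE) {
                out.retract(Tuple2.of(acc.oldSecond, 2));
            }
            out.collect(Tuple2.of(acc.second, 2));
            acc.oldSecond = acc.second;
        }
    }
}
{% endhighlight %}
Registration and use would be identical to the Top2 example above; per the text, the planner prefers emitUpdateWithRetract whenever the operator supports incremental emit.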
<p><b>Note:</b> For streaming queries, the required state to compute
the query result might grow infinitely depending on the type of aggregation and
the number of distinct grouping keys. Please provide a query configuration with
a valid retention interval to prevent excessive state size. See <a
href="streaming/query_configuration.html">Query Configuration</a> for
details.</p>
</td>
@@ -2697,14 +2714,13 @@ Table result = orders
<td>
<p>Groups and aggregates a table on a <a href="#group-windows">group window</a> and possibly one or more grouping keys. You have to close the "flatAggregate" with a select statement, and the select statement does not support aggregate functions.</p>
{% highlight java %}
-TableAggregateFunction tableAggFunc = new MyMinMax();
-tableEnv.registerFunction("tableAggFunc", tableAggFunc);
+tableEnv.registerFunction("top2", new Top2());
Table orders = tableEnv.scan("Orders");
Table result = orders
.window(Tumble.over("5.minutes").on("rowtime").as("w")) // define window
.groupBy("a, w") // group by key and window
- .flatAggregate("tableAggFunc(b) as (x, y)")
- .select("a, w.start, w.end, w.rowtime, x, y"); // access window properties
and aggregate results
+ .flatAggregate("top2(b) as (v, rank)")
+ .select("a, w.start, w.end, w.rowtime, v, rank"); // access window
properties and aggregate results
{% endhighlight %}
</td>
</tr>
@@ -2832,43 +2848,67 @@ val table = input
</td>
<td>
<p>Similar to a <b>GroupBy Aggregation</b>. Groups the rows on the grouping keys and applies the following running table aggregation operator to aggregate rows group-wise. The difference from an AggregateFunction is that a TableAggregateFunction may return 0 or more records for a group. You have to close the "flatAggregate" with a select statement, and the select statement does not support aggregate functions.</p>
+        <p>Instead of using <code>emitValue</code> to output results, you can also use the <code>emitUpdateWithRetract</code> method. Different from <code>emitValue</code>, <code>emitUpdateWithRetract</code> is used to emit values that have been updated. This method outputs data incrementally in retract mode, i.e., once there is an update, we have to retract old records before sending new updated ones. The <code>emitUpdateWithRetract</code> method will be used in preference to the <code>emitValue</code> method if both methods are defined in the table aggregate function, because the method is treated to be more efficient than <code>emitValue</code> as it can output values incrementally.</p>
{% highlight scala %}
-case class MyMinMaxAcc(var min: Int, var max: Int)
+import java.lang.{Integer => JInteger, Iterable => JIterable}
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.functions.TableAggregateFunction
+import org.apache.flink.util.Collector
+
+/**
+ * Accumulator for top2.
+ */
+class Top2Accum {
+ var first: JInteger = _
+ var second: JInteger = _
+}
-class MyMinMax extends TableAggregateFunction[Row, MyMinMaxAcc] {
+/**
+ * The top2 user-defined table aggregate function.
+ */
+class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {
- def accumulate(acc: MyMinMaxAcc, value: Int): Unit = {
- if (value < acc.min) {
- acc.min = value
- }
- if (value > acc.max) {
- acc.max = value
- }
+ override def createAccumulator(): Top2Accum = {
+ val acc = new Top2Accum
+ acc.first = Int.MinValue
+ acc.second = Int.MinValue
+ acc
}
- def resetAccumulator(acc: MyMinMaxAcc): Unit = {
- acc.min = 0
- acc.max = 0
+ def accumulate(acc: Top2Accum, v: Int): Unit = {
+ if (v > acc.first) {
+ acc.second = acc.first
+ acc.first = v
+ } else if (v > acc.second) {
+ acc.second = v
+ }
}
- override def createAccumulator(): MyMinMaxAcc = MyMinMaxAcc(0, 0)
-
- def emitValue(acc: MyMinMaxAcc, out: Collector[Row]): Unit = {
- out.collect(Row.of(Integer.valueOf(acc.min), Integer.valueOf(acc.min)))
- out.collect(Row.of(Integer.valueOf(acc.max), Integer.valueOf(acc.max)))
+ def merge(acc: Top2Accum, its: JIterable[Top2Accum]): Unit = {
+ val iter = its.iterator()
+ while (iter.hasNext) {
+ val top2 = iter.next()
+ accumulate(acc, top2.first)
+ accumulate(acc, top2.second)
+ }
}
- override def getResultType: TypeInformation[Row] = {
- new RowTypeInfo(Types.INT, Types.INT)
+ def emitValue(acc: Top2Accum, out: Collector[JTuple2[JInteger, JInteger]]): Unit = {
+ // emit the value and rank
+ if (acc.first != Int.MinValue) {
+ out.collect(JTuple2.of(acc.first, 1))
+ }
+ if (acc.second != Int.MinValue) {
+ out.collect(JTuple2.of(acc.second, 2))
+ }
}
}
-val tableAggFunc = new MyMinMax
+val top2 = new Top2
val orders: Table = tableEnv.scan("Orders")
val result = orders
- .groupBy('a)
- .flatAggregate(tableAggFunc('b) as ('x, 'y))
- .select('a, 'x, 'y)
+ .groupBy('key)
+ .flatAggregate(top2('a) as ('v, 'rank))
+ .select('key, 'v, 'rank)
{% endhighlight %}
<p><b>Note:</b> For streaming queries, the required state to compute
the query result might grow infinitely depending on the type of aggregation and
the number of distinct grouping keys. Please provide a query configuration with
a valid retention interval to prevent excessive state size. See <a
href="streaming/query_configuration.html">Query Configuration</a> for
details.</p>
</td>
@@ -2882,13 +2922,13 @@ val result = orders
<td>
<p>Groups and aggregates a table on a <a href="#group-windows">group window</a> and possibly one or more grouping keys. You have to close the "flatAggregate" with a select statement, and the select statement does not support aggregate functions.</p>
{% highlight scala %}
-val tableAggFunc = new MyMinMax
+val top2 = new Top2
val orders: Table = tableEnv.scan("Orders")
val result = orders
.window(Tumble over 5.minutes on 'rowtime as 'w) // define window
.groupBy('a, 'w) // group by key and window
- .flatAggregate(tableAggFunc('b) as ('x, 'y))
- .select('a, w.start, 'w.end, 'w.rowtime, 'x, 'y) // access window properties and aggregate results
+ .flatAggregate(top2('b) as ('v, 'rank))
+ .select('a, 'w.start, 'w.end, 'w.rowtime, 'v, 'rank) // access window properties and aggregate results
{% endhighlight %}
</td>
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java
index 4224983..c8d3aef 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/TableAggregateFunction.java
@@ -19,6 +19,7 @@
package org.apache.flink.table.functions;
import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.util.Collector;
/**
* Base class for user-defined table aggregates.
@@ -28,7 +29,7 @@ import org.apache.flink.annotation.PublicEvolving;
* <ul>
* <li>createAccumulator</li>
* <li>accumulate</li>
- * <li>emitValue</li>
+ * <li>emitValue or emitUpdateWithRetract</li>
* </ul>
*
* <p>There is another method that can be optional to have:
@@ -79,6 +80,28 @@ import org.apache.flink.annotation.PublicEvolving;
* }
* </pre>
*
+ * <pre>
+ * {@code
+ * Called every time an aggregation result should be materialized. The returned value
+ * could be either an early and incomplete result (periodically emitted as data arrives)
+ * or the final result of the aggregation.
+ *
+ * Different from emitValue, emitUpdateWithRetract is used to emit values that have been
+ * updated. This method outputs data incrementally in retract mode, i.e., once there is an
+ * update, we have to retract old records before sending new updated ones. The
+ * emitUpdateWithRetract method will be used in preference to the emitValue method if both
+ * methods are defined in the table aggregate function, because the method is treated to be
+ * more efficient than emitValue as it can output values incrementally.
+ *
+ * param: accumulator the accumulator which contains the current aggregated results
+ * param: out         the retractable collector used to output data. Use the collect method
+ *                    to output (add) records and the retract method to retract (delete)
+ *                    records.
+ *
+ * public void emitUpdateWithRetract(ACC accumulator, RetractableCollector<T> out)
+ * }
+ * </pre>
+ *
* @param <T> the type of the table aggregation result
* @param <ACC> the type of the table aggregation accumulator. The accumulator
is used to keep the
* aggregated values which are needed to compute an aggregation
result.
@@ -88,4 +111,17 @@ import org.apache.flink.annotation.PublicEvolving;
@PublicEvolving
public abstract class TableAggregateFunction<T, ACC> extends
UserDefinedAggregateFunction<T, ACC> {
+ /**
+ * Collects a record and forwards it. The collector can output retract messages with the
+ * retract method. Note: only use it in {@code emitUpdateWithRetract}.
+ */
+ public interface RetractableCollector<T> extends Collector<T> {
+
+ /**
+ * Retract a record.
+ *
+ * @param record The record to retract.
+ */
+ void retract(T record);
+ }
}
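As a quick orientation for the new interface, here is a minimal in-memory implementation one might write in a unit test for an emitUpdateWithRetract method. The BufferingRetractableCollector name and buffering behavior are illustrative, not part of this patch:
{% highlight java %}
import java.util.ArrayList;
import java.util.List;

import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector;

/** Buffers added and retracted records so a test can assert on them. */
public class BufferingRetractableCollector<T> implements RetractableCollector<T> {

    public final List<T> collected = new ArrayList<>();
    public final List<T> retracted = new ArrayList<>();

    @Override
    public void collect(T record) {
        collected.add(record);
    }

    @Override
    public void retract(T record) {
        retracted.add(record);
    }

    @Override
    public void close() {
        // nothing to release for in-memory buffers
    }
}
{% endhighlight %}
A test can pass this collector to emitUpdateWithRetract and assert that each update produces matching retract/collect pairs.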
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
index 2f39991..2de3f0f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
@@ -28,16 +28,17 @@ import org.apache.flink.api.common.state.{ListStateDescriptor, MapStateDescripto
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.api.java.typeutils.TypeExtractionUtils.{extractTypeArgument, getRawClass}
-import org.apache.flink.table.api.TableConfig
+import org.apache.flink.table.api.{TableConfig, ValidationException}
import org.apache.flink.table.api.dataview._
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.codegen.CodeGenUtils.{newName, reflectiveFieldWriteAccess}
import org.apache.flink.table.codegen.Indenter.toISC
import org.apache.flink.table.dataview.{StateListView, StateMapView}
-import org.apache.flink.table.functions.{TableAggregateFunction, UserDefinedAggregateFunction}
+import org.apache.flink.table.functions.{TableAggregateFunction, UserDefinedAggregateFunction, UserDefinedFunction}
import org.apache.flink.table.functions.aggfunctions.DistinctAccumulator
import org.apache.flink.table.functions.utils.{AggSqlFunction, UserDefinedFunctionUtils}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getUserDefinedMethod, signatureToString}
+import org.apache.flink.table.runtime.CRowWrappingCollector
import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
import org.apache.flink.table.runtime.aggregate.{AggregateUtil, GeneratedAggregations, GeneratedTableAggregations, SingleElementIterable}
import org.apache.flink.table.utils.EncodingUtils
@@ -840,18 +841,33 @@ class AggregationCodeGenerator(
*/
def generateTableAggregations(
tableAggOutputRowType: RowTypeInfo,
- tableAggOutputType: TypeInformation[_]): GeneratedAggregationsFunction = {
+ tableAggOutputType: TypeInformation[_],
+ supportEmitIncrementally: Boolean): GeneratedAggregationsFunction = {
// constants
val CONVERT_COLLECTOR_CLASS_TERM = "ConvertCollector"
-
val CONVERT_COLLECTOR_VARIABLE_TERM = "convertCollector"
val COLLECTOR_VARIABLE_TERM = "cRowWrappingcollector"
val CONVERTER_ROW_RESULT_TERM = "rowTerm"
+ // emit methods
+ val emitValue = "emitValue"
+ val emitUpdateWithRetract = "emitUpdateWithRetract"
+
+ // collectors
val COLLECTOR: String = classOf[Collector[_]].getCanonicalName
+ val CROW_WRAPPING_COLLECTOR: String = classOf[CRowWrappingCollector].getCanonicalName
+ val RETRACTABLE_COLLECTOR: String =
+ classOf[TableAggregateFunction.RetractableCollector[_]].getCanonicalName
+
val ROW: String = classOf[Row].getCanonicalName
+ // emitValue is the default emit method. It is switched to emitUpdateWithRetract only if:
+ // 1. emitUpdateWithRetract has been defined in the table aggregate function, and
+ // 2. the operator supports incremental emit (for example, window flatAggregate doesn't
+ // support incremental emit yet).
+ var finalEmitMethodName: String = emitValue
+
def genEmit: String = {
val sig: String =
@@ -865,13 +881,14 @@ class AggregationCodeGenerator(
val emitAcc =
j"""
| ${genAccDataViewFieldSetter(s"acc$i", i)}
- | ${aggs(i)}.emitValue(acc$i
+ | ${aggs(i)}.$finalEmitMethodName(acc$i
| ${if (!parametersCode(i).isEmpty) "," else ""}
| $CONVERT_COLLECTOR_VARIABLE_TERM);
""".stripMargin
j"""
| ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
- | $CONVERT_COLLECTOR_VARIABLE_TERM.$COLLECTOR_VARIABLE_TERM = collector;
+ | $CONVERT_COLLECTOR_VARIABLE_TERM.$COLLECTOR_VARIABLE_TERM =
+ | ($CROW_WRAPPING_COLLECTOR) collector;
| $emitAcc
""".stripMargin
}
@@ -900,22 +917,63 @@ class AggregationCodeGenerator(
functionGenerator.reuseInputUnboxingCode() + resultExprs.code
}
+ def checkAndGetEmitValueMethod(function: UserDefinedFunction, index: Int): Unit = {
+ finalEmitMethodName = emitValue
+ getUserDefinedMethod(
+ function, emitValue, Array(accTypeClasses(index), classOf[Collector[_]]))
+ .getOrElse(throw new CodeGenException(
+ s"No matching $emitValue method found for " +
+ s"tableAggregate '${function.getClass.getCanonicalName}'."))
+ }
+
/**
* Call super init and check emit methods.
*/
def innerInit(): Unit = {
init()
- // check and validate the emit methods
+ // Check and validate the emit methods. Look for the incremental emit method first if
+ // the operator supports incremental emit.
aggregates.zipWithIndex.map {
case (a, i) =>
- val methodName = "emitValue"
- getUserDefinedMethod(
- a, methodName, Array(accTypeClasses(i), classOf[Collector[_]]))
- .getOrElse(
- throw new CodeGenException(
- s"No matching $methodName method found for " +
- s"tableAggregate ${a.getClass.getCanonicalName}'.")
- )
+ if (supportEmitIncrementally) {
+ try {
+ finalEmitMethodName = emitUpdateWithRetract
+ getUserDefinedMethod(
+ a,
+ emitUpdateWithRetract,
+ Array(accTypeClasses(i), classOf[TableAggregateFunction.RetractableCollector[_]]))
+ .getOrElse(checkAndGetEmitValueMethod(a, i))
+ } catch {
+ case _: ValidationException =>
+ // Use try-catch here, as a ValidationException is thrown if there is no
+ // emitUpdateWithRetract method
+ checkAndGetEmitValueMethod(a, i)
+ }
+ } else {
+ checkAndGetEmitValueMethod(a, i)
+ }
+ }
+ }
+
+ /**
+ * Generates the retract method if the collector is a
+ * [[TableAggregateFunction.RetractableCollector]].
+ */
+ def getRetractMethodForConvertCollector(emitMethodName: String): String = {
+ if (emitMethodName == emitValue) {
+ // Users can't retract messages with the emitValue method.
+ j"""
+ |
+ """.stripMargin
+ } else {
+ // only generate the retract method for a RetractableCollector
+ j"""
+ | @Override
+ | public void retract(Object record) throws Exception {
+ | $COLLECTOR_VARIABLE_TERM.setChange(false);
+ | $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
+ | $COLLECTOR_VARIABLE_TERM.setChange(true);
+ | }
+ """.stripMargin
}
}
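To make the generated string above concrete: the pattern it emits boils down to flipping the wrapping collector's change flag around a regular collect. A hedged, self-contained sketch follows; the names are illustrative, not the actual generated identifiers:
{% highlight java %}
import java.util.ArrayList;
import java.util.List;

/** Sketch of the retract-as-flagged-collect pattern in the generated collector. */
public class ChangeFlagCollectorSketch {

    private boolean change = true;                 // true = add, false = retract
    private final List<String> emitted = new ArrayList<>();

    public void collect(Object row) {
        // tag each record the way CRow carries its change flag
        emitted.add((change ? "+" : "-") + row);
    }

    public void retract(Object row) {
        change = false;
        collect(row);
        change = true;
    }

    public List<String> getEmitted() {
        return emitted;
    }
}
{% endhighlight %}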
@@ -931,6 +989,10 @@ class AggregationCodeGenerator(
val generatedAggregationsClass =
classOf[GeneratedTableAggregations].getCanonicalName
val aggOutputTypeName = tableAggOutputType.getTypeClass.getCanonicalName
+
+ val baseCollectorString =
+ if (finalEmitMethodName == emitValue) COLLECTOR else RETRACTABLE_COLLECTOR
+
val funcCode =
j"""
|public final class $funcName extends $generatedAggregationsClass {
@@ -959,9 +1021,9 @@ class AggregationCodeGenerator(
| ${reuseCloseCode()}
| }
|
- | private class $CONVERT_COLLECTOR_CLASS_TERM implements $COLLECTOR {
+ | private class $CONVERT_COLLECTOR_CLASS_TERM implements $baseCollectorString {
|
- | public $COLLECTOR<$ROW> $COLLECTOR_VARIABLE_TERM;
+ | public $CROW_WRAPPING_COLLECTOR $COLLECTOR_VARIABLE_TERM;
| private final $ROW $CONVERTER_ROW_RESULT_TERM =
| new $ROW(${tableAggOutputType.getArity});
|
@@ -976,6 +1038,8 @@ class AggregationCodeGenerator(
| $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
| }
|
+ | ${getRetractMethodForConvertCollector(finalEmitMethodName)}
+ |
| @Override
| public void close() {
| $COLLECTOR_VARIABLE_TERM.close();
@@ -984,7 +1048,7 @@ class AggregationCodeGenerator(
|}
""".stripMargin
- GeneratedAggregationsFunction(funcName, funcCode)
+ new GeneratedTableAggregationsFunction(funcName, funcCode, finalEmitMethodName != emitValue)
}
/**
@@ -994,12 +1058,15 @@ class AggregationCodeGenerator(
* @param outputType Output type of the (table)aggregate node.
* @param groupSize The size of the groupings.
* @param namedAggregates The correspond named aggregates in the aggregate
operator.
+ * @param supportEmitIncrementally Whether values can be emitted incrementally. Window
+ * operators don't support it yet.
* @return A GeneratedAggregationsFunction
*/
def genAggregationsOrTableAggregations(
outputType: RelDataType,
groupSize: Int,
- namedAggregates: Seq[CalcitePair[AggregateCall, String]])
+ namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+ supportEmitIncrementally: Boolean)
: GeneratedAggregationsFunction = {
if (isTableAggregate) {
@@ -1017,7 +1084,8 @@ class AggregationCodeGenerator(
generateTableAggregations(
tableAggOutputRowType,
- namedAggregates.head.left.getAggregation.asInstanceOf[AggSqlFunction].returnType)
+ namedAggregates.head.left.getAggregation.asInstanceOf[AggSqlFunction].returnType,
+ supportEmitIncrementally)
} else {
generateAggregations
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/generated.scala
index 3301e81..a91377f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/generated.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/generated.scala
@@ -62,7 +62,7 @@ case class GeneratedFunction[F <: Function, T <: Any](
code: String)
/**
- * Describes a generated aggregate or table aggregate helper function
+ * Describes a generated aggregate helper function
*
* @param name class name of the generated Function.
* @param code code of the generated Function.
@@ -72,6 +72,18 @@ case class GeneratedAggregationsFunction(
code: String)
/**
+ * Describes a generated table aggregate helper function
+ *
+ * @param name class name of the generated Function.
+ * @param code code of the generated Function.
+ * @param emitValuesIncrementally whether values are emitted incrementally.
+ */
+class GeneratedTableAggregationsFunction(
+ name: String,
+ code: String,
+ val emitValuesIncrementally: Boolean) extends GeneratedAggregationsFunction(name, code)
+
+/**
* Describes a generated [[InputFormat]].
*
* @param name class name of the generated input function.
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index f44b4bd..83ef569 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -37,7 +37,7 @@ import org.apache.flink.table.api.dataview.DataViewSpec
import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableException}
import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.codegen.AggregationCodeGenerator
+import org.apache.flink.table.codegen.{AggregationCodeGenerator, GeneratedTableAggregationsFunction}
import org.apache.flink.table.expressions.PlannerExpressionUtils.isTimeIntervalLiteral
import org.apache.flink.table.expressions._
import org.apache.flink.table.functions.aggfunctions._
@@ -225,19 +225,21 @@ object AggregateUtil {
accConfig = Some(aggregateMetadata.getAggregatesAccumulatorSpecs)
)
- val genAggregations = generator
- .genAggregationsOrTableAggregations(outputType, groupings.length, namedAggregates)
val aggregationStateType: RowTypeInfo = new RowTypeInfo(aggregateMetadata
.getAggregatesAccumulatorTypes: _*)
if (isTableAggregate) {
+ val genAggregations = generator
+ .genAggregationsOrTableAggregations(outputType, groupings.length, namedAggregates, true)
new GroupTableAggProcessFunction[K](
- genAggregations,
+ genAggregations.asInstanceOf[GeneratedTableAggregationsFunction],
aggregationStateType,
generateRetraction,
groupings.length,
queryConfig)
} else {
+ val genAggregations = generator
+ .genAggregationsOrTableAggregations(outputType, groupings.length, namedAggregates, false)
new GroupAggProcessFunction[K](
genAggregations,
aggregationStateType,
@@ -1212,7 +1214,8 @@ object AggregateUtil {
val genAggregations = generator.genAggregationsOrTableAggregations(
outputType,
groupingKeys.length,
- namedAggregates)
+ namedAggregates,
+ false)
val aggFunction = new AggregateAggFunction(genAggregations, isTableAggregate)
(aggFunction, accumulatorRowType)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala
index f949bf7..07e647b 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupTableAggProcessFunction.scala
@@ -24,7 +24,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.table.api.{StreamQueryConfig, Types}
-import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedTableAggregationsFunction}
import org.apache.flink.table.runtime.TableAggregateCollector
import org.apache.flink.table.runtime.types.CRow
import org.apache.flink.table.util.Logging
@@ -38,7 +38,7 @@ import org.apache.flink.util.Collector
* @param aggregationStateType The row type info of aggregation
*/
class GroupTableAggProcessFunction[K](
- private val genTableAggregations: GeneratedAggregationsFunction,
+ private val genTableAggregations: GeneratedTableAggregationsFunction,
private val aggregationStateType: RowTypeInfo,
private val generateRetraction: Boolean,
private val groupKeySize: Int,
@@ -115,7 +115,9 @@ class GroupTableAggProcessFunction[K](
concatCollector.out = out
if (!firstRow) {
- if (generateRetraction) {
+ // retractions are generated manually by the TableAggregateFunction if an incremental
+ // emit method has been defined
+ if (generateRetraction && !genTableAggregations.emitValuesIncrementally) {
concatCollector.setChange(false)
function.emit(accumulators, concatCollector)
concatCollector.setChange(true)
@@ -140,13 +142,13 @@ class GroupTableAggProcessFunction[K](
if (inputCnt != 0) {
// we aggregated at least one record for this key
+ // emit the new result
+ function.emit(accumulators, concatCollector)
+
// update the state
state.update(accumulators)
cntState.update(inputCnt)
- // emit the new result
- function.emit(accumulators, concatCollector)
-
} else {
// and clear all state
state.clear()
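For orientation, a hedged Java sketch of the branch above: in emitValue mode the operator itself retracts the whole previous result set before re-emitting, while in incremental mode that step is skipped because emitUpdateWithRetract produces its own retract/add pairs. All names are illustrative stand-ins, not Flink API:
{% highlight java %}
import java.util.ArrayList;
import java.util.List;

/** Sketch of the two emit modes of the group table aggregate operator. */
public class EmitModeSketch {

    /** Records are tagged like CRow: true = add, false = retract. */
    static final class TaggedRecord {
        final boolean change;
        final String row;

        TaggedRecord(boolean change, String row) {
            this.change = change;
            this.row = row;
        }
    }

    final List<TaggedRecord> out = new ArrayList<>();
    boolean generateRetraction = true;
    boolean emitValuesIncrementally;

    void onUpdate(String oldResult, String newResult, boolean firstRow) {
        if (!firstRow && generateRetraction && !emitValuesIncrementally) {
            // emitValue mode: retract the previously emitted result set first
            out.add(new TaggedRecord(false, oldResult));
        }
        // emit the new result; in incremental mode the function has already
        // interleaved its own retractions through the RetractableCollector
        out.add(new TaggedRecord(true, newResult));
    }
}
{% endhighlight %}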
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala
index 7d92f2b..56d79d5 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/harness/TableAggregateHarnessTest.scala
@@ -27,7 +27,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord
import org.apache.flink.table.api.scala._
import org.apache.flink.table.runtime.harness.HarnessTestBase._
import org.apache.flink.table.runtime.types.CRow
-import org.apache.flink.table.utils.{Top3WithMapView}
+import org.apache.flink.table.utils.{Top3WithEmitRetractValue, Top3WithMapView}
import org.apache.flink.types.Row
import org.junit.Test
@@ -104,6 +104,59 @@ class TableAggregateHarnessTest extends HarnessTestBase {
}
@Test
+ def testTableAggregateEmitRetractValueIncrementally(): Unit = {
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ val tEnv = StreamTableEnvironment.create(env)
+
+ val top3 = new Top3WithEmitRetractValue
+ val source = env.fromCollection(data).toTable(tEnv, 'a, 'b)
+ val resultTable = source
+ .groupBy('a)
+ .flatAggregate(top3('b) as ('b1, 'b2))
+ .select('a, 'b1, 'b2)
+
+ val testHarness = createHarnessTester[Int, CRow, CRow](
+ resultTable.toRetractStream[Row](queryConfig), "groupBy: (a)")
+
+ testHarness.open()
+
+ val expectedOutput = new ConcurrentLinkedQueue[Object]()
+
+ // register cleanup timer with 3001
+ testHarness.setProcessingTime(1)
+
+ // input with two columns: key and value
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 1: JInt), 1))
+ // output with three columns: key, value, value. The value is in the top3 of the key
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 1: JInt, 1: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 3: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 3: JInt, 3: JInt), 1))
+
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(false, 1: JInt, 1: JInt, 1: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+
+ // ingest data with key value of 2
+ testHarness.processElement(new StreamRecord(CRow(2: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(2: JInt, 2: JInt, 2: JInt), 1))
+
+ // trigger cleanup timer
+ testHarness.setProcessingTime(3002)
+ testHarness.processElement(new StreamRecord(CRow(1: JInt, 2: JInt), 1))
+ expectedOutput.add(new StreamRecord(CRow(1: JInt, 2: JInt, 2: JInt), 1))
+
+ val result = testHarness.getOutput
+
+ verify(expectedOutput, result)
+ testHarness.close()
+ }
+
+ @Test
def testTableAggregateWithRetractInput(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala
index d21c0a9..a45097a 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/runtime/stream/table/TableAggregateITCase.scala
@@ -19,13 +19,12 @@
package org.apache.flink.table.runtime.stream.table
import org.apache.flink.api.common.time.Time
-import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.api.scala._
-import org.apache.flink.table.api.{StreamQueryConfig, Types, ValidationException}
+import org.apache.flink.table.api.{StreamQueryConfig, ValidationException}
import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
-import org.apache.flink.table.utils.{Top3, Top3WithMapView}
+import org.apache.flink.table.utils.{EmptyTableAggFuncWithoutEmit, Top3, Top3WithEmitRetractValue, Top3WithMapView}
import org.apache.flink.types.Row
import org.junit.Assert.assertEquals
import org.junit.Test
@@ -76,6 +75,44 @@ class TableAggregateITCase extends StreamingWithStateTestBase {
}
@Test
+ def testEmitRetractValueIncrementally(): Unit = {
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = StreamTableEnvironment.create(env)
+ StreamITCase.clear
+
+ val top3 = new Top3WithEmitRetractValue
+ val source = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+ val resultTable = source.groupBy('b)
+ .flatAggregate(top3('a))
+ .select('b, 'f0, 'f1)
+ .as('category, 'v1, 'v2)
+
+ val results = resultTable.toRetractStream[Row](queryConfig)
+ results.addSink(new StreamITCase.RetractingSink).setParallelism(1)
+ env.execute()
+
+ val expected = List(
+ "1,1,1",
+ "2,2,2",
+ "2,3,3",
+ "3,4,4",
+ "3,5,5",
+ "3,6,6",
+ "4,10,10",
+ "4,9,9",
+ "4,8,8",
+ "5,15,15",
+ "5,14,14",
+ "5,13,13",
+ "6,21,21",
+ "6,20,20",
+ "6,19,19"
+ ).sorted
+ assertEquals(expected, StreamITCase.retractedResults.sorted)
+ }
+
+ @Test
def testNonkeyedFlatAggregate(): Unit = {
val env = StreamExecutionEnvironment.getExecutionEnvironment
env.setStateBackend(getStateBackend)
@@ -140,12 +177,6 @@ class TableAggregateITCase extends StreamingWithStateTestBase {
val tEnv = StreamTableEnvironment.create(env)
StreamITCase.clear
- tEnv.registerTableSink(
- "retractSink",
- new TestRetractSink().configure(
- Array[String]("v1", "v2"),
- Array[TypeInformation[_]](Types.INT, Types.INT)))
-
val top3 = new Top3
val source = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
source
@@ -153,7 +184,29 @@ class TableAggregateITCase extends StreamingWithStateTestBase {
.select('b, 'a.sum as 'a)
.flatAggregate(top3('a) as ('v1, 'v2))
.select('v1, 'v2)
- .insertInto("retractSink")
+ .toRetractStream[Row]
+
+ env.execute()
+ }
+
+ @Test
+ def testTableAggFunctionWithoutEmitValueMethod(): Unit = {
+ expectedException.expect(classOf[ValidationException])
+ expectedException.expectMessage("Function class
'org.apache.flink.table.utils." +
+ "EmptyTableAggFuncWithoutEmit' does not implement at least one method
named 'emitValue' " +
+ "which is public, not abstract and (in case of table functions) not
static.")
+
+ val env = StreamExecutionEnvironment.getExecutionEnvironment
+ env.setStateBackend(getStateBackend)
+ val tEnv = StreamTableEnvironment.create(env)
+ StreamITCase.clear
+
+ val func = new EmptyTableAggFuncWithoutEmit
+ val source = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+ source
+ .flatAggregate(func('a) as ('v1, 'v2))
+ .select('v1, 'v2)
+ .toRetractStream[Row]
env.execute()
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala
index 4000364..1a50a4c 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/utils/UserDefinedTableAggFunctions.scala
@@ -20,7 +20,6 @@ package org.apache.flink.table.utils
import org.apache.flink.table.functions.TableAggregateFunction
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-
import java.lang.{Integer => JInt}
import java.lang.{Iterable => JIterable}
import java.sql.Timestamp
@@ -28,8 +27,11 @@ import java.util
import org.apache.flink.table.api.Types
import org.apache.flink.table.api.dataview.MapView
+import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector
import org.apache.flink.util.Collector
+import scala.collection.mutable.ListBuffer
+
class Top3Accum {
var data: util.HashMap[JInt, JInt] = _
var size: JInt = _
@@ -128,6 +130,43 @@ class Top3WithMapViewAccum {
var smallest: JInt = _
}
+class Top3WithEmitRetractValue extends Top3 {
+
+ // buffers of values added and retracted since the last emit; note that the `add` buffer
+ // overloads the inherited add(acc, v) helper from Top3, which is still called with two
+ // arguments below
+ val add: ListBuffer[Int] = new ListBuffer[Int]
+ val retract: ListBuffer[Int] = new ListBuffer[Int]
+
+ override def accumulate(acc: Top3Accum, v: Int): Unit = {
+ if (acc.size == 0) {
+ acc.size = 1
+ acc.smallest = v
+ acc.data.put(v, 1)
+ add.append(v)
+ } else if (acc.size < 3) {
+ add(acc, v)
+ if (v < acc.smallest) {
+ acc.smallest = v
+ }
+ add.append(v)
+ } else if (v > acc.smallest) {
+ delete(acc, acc.smallest)
+ retract.append(acc.smallest)
+ add(acc, v)
+ add.append(v)
+ updateSmallest(acc)
+ }
+ }
+
+ def emitUpdateWithRetract(
+ acc: Top3Accum,
+ out: RetractableCollector[JTuple2[JInt, JInt]])
+ : Unit = {
+ retract.foreach(e => out.retract(JTuple2.of(e, e)))
+ add.foreach(e => out.collect(JTuple2.of(e, e)))
+ retract.clear()
+ add.clear()
+ }
+}
+
/**
* Note: This function suffers performance problem. Only use it in tests.
*/
@@ -207,10 +246,7 @@ class Top3WithMapView extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Wi
}
}
-/**
- * Test function for plan test.
- */
-class EmptyTableAggFunc extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Accum] {
+class EmptyTableAggFuncWithoutEmit extends TableAggregateFunction[JTuple2[JInt, JInt], Top3Accum] {
override def createAccumulator(): Top3Accum = new Top3Accum
@@ -219,6 +255,12 @@ class EmptyTableAggFunc extends TableAggregateFunction[JTuple2[JInt, JInt], Top3
def accumulate(acc: Top3Accum, category: Long, value: Int): Unit = {}
def accumulate(acc: Top3Accum, value: Int): Unit = {}
+}
+
+/**
+ * Test function for plan test.
+ */
+class EmptyTableAggFunc extends EmptyTableAggFuncWithoutEmit {
def emitValue(acc: Top3Accum, out: Collector[JTuple2[JInt, JInt]]): Unit = {}
}