[FLINK-7658] [table] Add Collect aggregate function to Table API.

This closes #5472.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a4442f89
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a4442f89
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a4442f89

Branch: refs/heads/master
Commit: a4442f893b3b6f927e334f8090f5a57e5b953fe5
Parents: 75b8398
Author: Shuyi Chen <sh...@uber.com>
Authored: Tue Feb 13 01:18:12 2018 -0800
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Wed Feb 14 20:13:11 2018 +0100

----------------------------------------------------------------------
 docs/dev/table/tableApi.md                      | 23 +++++++++++++++++++-
 .../flink/table/api/scala/expressionDsl.scala   |  5 +++++
 .../flink/table/expressions/aggregations.scala  | 22 +++++++++++++++++++
 .../flink/table/validate/FunctionCatalog.scala  |  1 +
 .../AggregateStringExpressionTest.scala         |  4 ++--
 .../runtime/batch/table/AggregateITCase.scala   | 16 ++++++++++++++
 .../runtime/stream/table/AggregateITCase.scala  | 21 ++++++++++++++++++
 7 files changed, 89 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/docs/dev/table/tableApi.md
----------------------------------------------------------------------
diff --git a/docs/dev/table/tableApi.md b/docs/dev/table/tableApi.md
index 0408c3c..4e7c544 100644
--- a/docs/dev/table/tableApi.md
+++ b/docs/dev/table/tableApi.md
@@ -2885,6 +2885,17 @@ FIELD.varSamp
       </td>
     </tr>
 
+    <tr>
+      <td>
+        {% highlight java %}
+FIELD.collect
+        {% endhighlight %}
+      </td>
+      <td>
+        <p>Returns a multiset of the values of the field across all input values.</p>
+      </td>
+    </tr>
+
     </tbody>
 </table>
 
@@ -4294,6 +4305,17 @@ FIELD.varSamp
         <p>Returns the sample variance (square of the sample standard 
deviation) of the numeric field across all input values.</p>
       </td>
     </tr>
+
+    <tr>
+      <td>
+        {% highlight scala %}
+FIELD.collect
+        {% endhighlight %}
+      </td>
+      <td>
+        <p>Returns a multiset of the values of the field across all input values.</p>
+      </td>
+    </tr>
   </tbody>
 </table>
 
@@ -4551,7 +4573,6 @@ The following operations are not supported yet:
 
 - Binary string operators and functions
 - System functions
-- Collection functions
 - Aggregate functions like REGR_xxx
 - Distinct aggregate functions like COUNT DISTINCT
 

http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
index c520433..f73442b 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/scala/expressionDsl.scala
@@ -214,6 +214,11 @@ trait ImplicitExpressionOperations {
   def varSamp = VarSamp(expr)
 
   /**
+    * Returns the multiset aggregate of a given expression.
+    */
+  def collect = Collect(expr)
+
+  /**
     * Converts a value to a given type.
     *
     * e.g. "42".cast(Types.INT) leads to 42.

http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
index 47d1519..b39bd98 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/aggregations.scala
@@ -28,6 +28,7 @@ import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.functions.utils.AggSqlFunction
 import org.apache.flink.table.typeutils.TypeCheckUtils
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo
+import org.apache.flink.api.java.typeutils.MultisetTypeInfo
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._
 import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, 
ValidationSuccess}
@@ -158,6 +159,27 @@ case class Avg(child: Expression) extends Aggregation {
   }
 }
 
+/**
+  * Returns a multiset aggregate of the values of the child expression.
+  */
+case class Collect(child: Expression) extends Aggregation  {
+
+  override private[flink] def children: Seq[Expression] = Seq(child)
+
+  override private[flink] def resultType: TypeInformation[_] =
+    MultisetTypeInfo.getInfoFor(child.resultType)
+
+  override def toString: String = s"collect($child)"
+
+  override private[flink] def toAggCall(name: String)(implicit relBuilder: 
RelBuilder): AggCall = {
+    relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, false, false, null, 
name, child.toRexNode)
+  }
+
+  override private[flink] def getSqlAggFunction()(implicit relBuilder: 
RelBuilder) = {
+    SqlStdOperatorTable.COLLECT
+  }
+}
+
 case class StddevPop(child: Expression) extends Aggregation {
   override private[flink] def children: Seq[Expression] = Seq(child)
   override def toString = s"stddev_pop($child)"

http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
index 773b2f0..36c99e9 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala
@@ -176,6 +176,7 @@ object FunctionCatalog {
     "stddevSamp" -> classOf[StddevSamp],
     "varPop" -> classOf[VarPop],
     "varSamp" -> classOf[VarSamp],
+    "collect" -> classOf[Collect],
 
     // string functions
     "charLength" -> classOf[CharLength],

http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
index e148b47..4bbb101 100644
--- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/stringexpr/AggregateStringExpressionTest.scala
@@ -43,8 +43,8 @@ class AggregateStringExpressionTest extends TableTestBase {
     val util = batchTestUtil()
     val t = util.addTable[(Byte, Short, Int, Long, Float, Double, 
String)]("Table7")
 
-    val t1 = t.select('_1.avg, '_2.avg, '_3.avg, '_4.avg, '_5.avg, '_6.avg, 
'_7.count)
-    val t2 = t.select("_1.avg, _2.avg, _3.avg, _4.avg, _5.avg, _6.avg, 
_7.count")
+    val t1 = t.select('_1.avg, '_2.avg, '_3.avg, '_4.avg, '_5.avg, '_6.avg, 
'_7.count, '_7.collect)
+    val t2 = t.select("_1.avg, _2.avg, _3.avg, _4.avg, _5.avg, _6.avg, 
_7.count, _7.collect")
 
     verifyTableEquals(t1, t2)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
index 892e4f3..102ff0d 100644
--- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala
@@ -439,6 +439,22 @@ class AggregationsITCase(
     val results = t.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
+
+  @Test
+  def testCollect(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val t = CollectionDataSets.get3TupleDataSet(env).toTable(tEnv, 'a, 'b, 'c)
+      .groupBy('b)
+      .select('b, 'a.collect)
+
+    val expected =
+      "1,{1=1}\n2,{2=1, 3=1}\n3,{4=1, 5=1, 6=1}\n4,{8=1, 9=1, 10=1, 7=1}\n" +
+        "5,{11=1, 12=1, 13=1, 14=1, 15=1}\n6,{16=1, 17=1, 18=1, 19=1, 20=1, 
21=1}"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
 }
 
 case class WC(word: String, frequency: Long)

http://git-wip-us.apache.org/repos/asf/flink/blob/a4442f89/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
index 67558d9..db6820a 100644
--- 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
+++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala
@@ -157,6 +157,27 @@ class AggregateITCase extends StreamingWithStateTestBase {
   }
 
   @Test
+  def testCollect(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setStateBackend(getStateBackend)
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c)
+      .groupBy('b)
+      .select('b, 'a.collect)
+
+    val results = t.toRetractStream[Row](queryConfig)
+    results.addSink(new RetractingSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "1,{1=1}", "2,{2=1, 3=1}", "3,{4=1, 5=1, 6=1}", "4,{7=1, 8=1, 9=1, 
10=1}",
+      "5,{11=1, 12=1, 13=1, 14=1, 15=1}", "6,{16=1, 17=1, 18=1, 19=1, 20=1, 
21=1}")
+    assertEquals(expected.sorted, StreamITCase.retractedResults.sorted)
+  }
+
+  @Test
   def testGroupAggregateWithStateBackend(): Unit = {
     val env = StreamExecutionEnvironment.getExecutionEnvironment
     env.setStateBackend(getStateBackend)

Reply via email to