Repository: flink Updated Branches: refs/heads/master c80e76bd7 -> a273f645b
[FLINK-8838] [table] Add support for UNNEST on MultiSet fields This closes #5619. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/a273f645 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/a273f645 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/a273f645 Branch: refs/heads/master Commit: a273f645b4120b143ac122e4db91616cd32bec5b Parents: c80e76b Author: lincoln-lil <[email protected]> Authored: Fri Mar 2 20:05:44 2018 +0800 Committer: Timo Walther <[email protected]> Committed: Tue May 22 16:59:09 2018 +0200 ---------------------------------------------------------------------- .../plan/rules/logical/LogicalUnnestRule.scala | 27 ++-- .../table/plan/util/ExplodeFunctionUtil.scala | 132 ++++++++++++++----- .../runtime/batch/sql/AggregateITCase.scala | 31 ++++- .../table/runtime/stream/sql/SqlITCase.scala | 74 ++++++++++- 4 files changed, 218 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/a273f645/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala index 23dfc03..8ef9fd3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/logical/LogicalUnnestRule.scala @@ -32,7 +32,7 @@ import org.apache.calcite.sql.`type`.AbstractSqlType import org.apache.flink.table.api.TableException import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils -import org.apache.flink.table.plan.schema.ArrayRelDataType +import org.apache.flink.table.plan.schema.{ArrayRelDataType, MultisetRelDataType} import org.apache.flink.table.plan.util.ExplodeFunctionUtil class LogicalUnnestRule( @@ -76,22 +76,27 @@ class LogicalUnnestRule( case uc: Uncollect => // convert Uncollect into TableFunctionScan val cluster = correlate.getCluster + val dataType = uc.getInput.getRowType.getFieldList.get(0).getValue + val (componentType, explodeTableFunc) = dataType match { + case arrayType: ArrayRelDataType => + (arrayType.getComponentType, + ExplodeFunctionUtil.explodeTableFuncFromType(arrayType.typeInfo)) + case mt: MultisetRelDataType => + (mt.getComponentType, ExplodeFunctionUtil.explodeTableFuncFromType(mt.typeInfo)) + case _ => throw TableException(s"Unsupported UNNEST on type: ${dataType.toString}") + } - val arrayType = - uc.getInput.getRowType.getFieldList.get(0).getValue.asInstanceOf[ArrayRelDataType] - val componentType = arrayType.getComponentType - - // create table function - val explodeTableFunc = UserDefinedFunctionUtils.createTableSqlFunction( + // create sql function + val explodeSqlFunc = UserDefinedFunctionUtils.createTableSqlFunction( "explode", "explode", - ExplodeFunctionUtil.explodeTableFuncFromType(arrayType.typeInfo), - FlinkTypeFactory.toTypeInfo(arrayType.getComponentType), + explodeTableFunc, + FlinkTypeFactory.toTypeInfo(componentType), cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]) // create table function call val rexCall = cluster.getRexBuilder.makeCall( - explodeTableFunc, + explodeSqlFunc, uc.getInput.asInstanceOf[RelSubset] .getOriginal.asInstanceOf[LogicalProject].getChildExps ) @@ -104,7 +109,7 @@ class LogicalUnnestRule( ImmutableList.of(new RelDataTypeFieldImpl("f0", 0, componentType))) case _: RelRecordType => componentType case _ => throw TableException( - s"Unsupported array component type in UNNEST: ${componentType.toString}") + s"Unsupported component type in UNNEST: ${componentType.toString}") } // create table function scan http://git-wip-us.apache.org/repos/asf/flink/blob/a273f645/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ExplodeFunctionUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ExplodeFunctionUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ExplodeFunctionUtil.scala index 1bcc6d9..cfcaa84 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ExplodeFunctionUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ExplodeFunctionUtil.scala @@ -18,74 +18,146 @@ package org.apache.flink.table.plan.util +import java.util + import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo +import org.apache.flink.api.java.typeutils.{MultisetTypeInfo, ObjectArrayTypeInfo} +import org.apache.flink.table.api.TableException import org.apache.flink.table.functions.TableFunction -class ObjectExplodeTableFunc extends TableFunction[Object] { +abstract class ExplodeTableFunction[T] extends TableFunction[T] { + + def collectArray(array: Array[T]): Unit = { + if (null != array) { + var i = 0 + while (i < array.length) { + collect(array(i)) + i += 1 + } + } + } + + def collect(map: util.Map[T, Integer]): Unit = { + if (null != map) { + val it = map.entrySet().iterator() + while (it.hasNext) { + val item = it.next() + val key: T = item.getKey + val cnt: Int = item.getValue + var i = 0 + while (i < cnt) { + collect(key) + i += 1 + } + } + } + } +} + +class ObjectExplodeTableFunc extends ExplodeTableFunction[Object] { def eval(arr: Array[Object]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Object, Integer]): Unit = { + collect(map) } } -class FloatExplodeTableFunc extends TableFunction[Float] { +class FloatExplodeTableFunc extends ExplodeTableFunction[Float] { def eval(arr: Array[Float]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Float, Integer]): Unit = { + collect(map) } } -class ShortExplodeTableFunc extends TableFunction[Short] { +class ShortExplodeTableFunc extends ExplodeTableFunction[Short] { def eval(arr: Array[Short]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Short, Integer]): Unit = { + collect(map) } } -class IntExplodeTableFunc extends TableFunction[Int] { +class IntExplodeTableFunc extends ExplodeTableFunction[Int] { def eval(arr: Array[Int]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Int, Integer]): Unit = { + collect(map) } } -class LongExplodeTableFunc extends TableFunction[Long] { +class LongExplodeTableFunc extends ExplodeTableFunction[Long] { def eval(arr: Array[Long]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Long, Integer]): Unit = { + collect(map) } } -class DoubleExplodeTableFunc extends TableFunction[Double] { +class DoubleExplodeTableFunc extends ExplodeTableFunction[Double] { def eval(arr: Array[Double]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Double, Integer]): Unit = { + collect(map) } } -class ByteExplodeTableFunc extends TableFunction[Byte] { +class ByteExplodeTableFunc extends ExplodeTableFunction[Byte] { def eval(arr: Array[Byte]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Byte, Integer]): Unit = { + collect(map) } } -class BooleanExplodeTableFunc extends TableFunction[Boolean] { +class BooleanExplodeTableFunc extends ExplodeTableFunction[Boolean] { def eval(arr: Array[Boolean]): Unit = { - arr.foreach(collect) + collectArray(arr) + } + + def eval(map: util.Map[Boolean, Integer]): Unit = { + collect(map) } } object ExplodeFunctionUtil { - def explodeTableFuncFromType(ti: TypeInformation[_]):TableFunction[_] = { + def explodeTableFuncFromType(ti: TypeInformation[_]): TableFunction[_] = { ti match { - case pat: PrimitiveArrayTypeInfo[_] => { - pat.getComponentType match { - case BasicTypeInfo.INT_TYPE_INFO => new IntExplodeTableFunc - case BasicTypeInfo.LONG_TYPE_INFO => new LongExplodeTableFunc - case BasicTypeInfo.SHORT_TYPE_INFO => new ShortExplodeTableFunc - case BasicTypeInfo.FLOAT_TYPE_INFO => new FloatExplodeTableFunc - case BasicTypeInfo.DOUBLE_TYPE_INFO => new DoubleExplodeTableFunc - case BasicTypeInfo.BYTE_TYPE_INFO => new ByteExplodeTableFunc - case BasicTypeInfo.BOOLEAN_TYPE_INFO => new BooleanExplodeTableFunc - } - } + case pat: PrimitiveArrayTypeInfo[_] => createTableFuncByType(pat.getComponentType) + case _: ObjectArrayTypeInfo[_, _] => new ObjectExplodeTableFunc + case _: BasicArrayTypeInfo[_, _] => new ObjectExplodeTableFunc - case _ => throw new UnsupportedOperationException(ti.toString + "IS NOT supported") + + case mt: MultisetTypeInfo[_] => createTableFuncByType(mt.getElementTypeInfo) + + case _ => throw new TableException("Unnesting of '" + ti.toString + "' is not supported.") + } + } + + def createTableFuncByType(typeInfo: TypeInformation[_]): TableFunction[_] = { + typeInfo match { + case BasicTypeInfo.INT_TYPE_INFO => new IntExplodeTableFunc + case BasicTypeInfo.LONG_TYPE_INFO => new LongExplodeTableFunc + case BasicTypeInfo.SHORT_TYPE_INFO => new ShortExplodeTableFunc + case BasicTypeInfo.FLOAT_TYPE_INFO => new FloatExplodeTableFunc + case BasicTypeInfo.DOUBLE_TYPE_INFO => new DoubleExplodeTableFunc + case BasicTypeInfo.BYTE_TYPE_INFO => new ByteExplodeTableFunc + case BasicTypeInfo.BOOLEAN_TYPE_INFO => new BooleanExplodeTableFunc + case _ => new ObjectExplodeTableFunc } } } http://git-wip-us.apache.org/repos/asf/flink/blob/a273f645/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala index ac0b705..09ccfc4 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala @@ -22,9 +22,9 @@ import org.apache.calcite.runtime.SqlFunctions.{internalToTimestamp => toTimesta import org.apache.flink.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvgWithMergeAndReset import org.apache.flink.table.api.scala._ import org.apache.flink.table.functions.aggfunctions.CountAggFunction +import org.apache.flink.table.runtime.utils.JavaUserDefinedAggFunctions.WeightedAvgWithMergeAndReset import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode import org.apache.flink.table.utils.NonMergableCount @@ -35,7 +35,6 @@ import org.junit.runner.RunWith import org.junit.runners.Parameterized import scala.collection.JavaConverters._ -import scala.collection.mutable @RunWith(classOf[Parameterized]) class AggregateITCase( @@ -368,6 +367,34 @@ class AggregateITCase( } @Test + def testTumbleWindowAggregateWithCollectUnnest(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val ds = CollectionDataSets.get3TupleDataSet(env) + // create timestamps + .map(x => (x._1, x._2, x._3, toTimestamp(x._1 * 1000))) + tEnv.registerDataSet("t1", ds, 'a, 'b, 'c, 'ts) + + val t2 = tEnv.sqlQuery("SELECT b, COLLECT(b) as `set`" + + "FROM t1 " + + "GROUP BY b, TUMBLE(ts, INTERVAL '3' SECOND)") + tEnv.registerTable("t2", t2) + + val result = tEnv.sqlQuery("SELECT b, s FROM t2, UNNEST(t2.`set`) AS A(s) where b < 3") + .toDataSet[Row] + .collect() + + val expected = Seq( + "1,1", + "2,2", + "2,2" + ).mkString("\n") + + TestBaseUtils.compareResultAsText(result.asJava, expected) + } + + @Test def testTumbleWindowWithProperties(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) http://git-wip-us.apache.org/repos/asf/flink/blob/a273f645/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala index 9155ff9..e132349 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala @@ -480,7 +480,6 @@ class SqlITCase extends StreamingWithStateTestBase { @Test def testUnnestPrimitiveArrayFromTable(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) StreamITCase.clear @@ -512,7 +511,6 @@ class SqlITCase extends StreamingWithStateTestBase { @Test def testUnnestArrayOfArrayFromTable(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) StreamITCase.clear @@ -542,7 +540,6 @@ class SqlITCase extends StreamingWithStateTestBase { @Test def testUnnestObjectArrayFromTableWithFilter(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) StreamITCase.clear @@ -568,6 +565,77 @@ class SqlITCase extends StreamingWithStateTestBase { } @Test + def testUnnestMultiSetFromCollectResult(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + env.setStateBackend(getStateBackend) + StreamITCase.clear + + val data = List( + (1, 1, (12, "45.6")), + (2, 2, (12, "45.612")), + (3, 2, (13, "41.6")), + (4, 3, (14, "45.2136")), + (5, 3, (18, "42.6"))) + tEnv.registerTable("t1", env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c)) + + val t2 = tEnv.sqlQuery("SELECT b, COLLECT(c) as `set` FROM t1 GROUP BY b") + tEnv.registerTable("t2", t2) + + val result = tEnv + .sqlQuery("SELECT b, id, point FROM t2, UNNEST(t2.`set`) AS A(id, point) WHERE b < 3") + .toRetractStream[Row] + result.addSink(new StreamITCase.RetractingSink).setParallelism(1) + env.execute() + + val expected = List( + "1,12,45.6", + "2,12,45.612", + "2,13,41.6") + assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) + } + + @Test + def testLeftUnnestMultiSetFromCollectResult(): Unit = { + val data = List( + (1, "1", "Hello"), + (1, "2", "Hello2"), + (2, "2", "Hello"), + (3, null.asInstanceOf[String], "Hello"), + (4, "4", "Hello"), + (5, "5", "Hello"), + (5, null.asInstanceOf[String], "Hello"), + (6, "6", "Hello"), + (7, "7", "Hello World"), + (7, "8", "Hello World")) + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t1 = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("t1", t1) + + val t2 = tEnv.sqlQuery("SELECT a, COLLECT(b) as `set` FROM t1 GROUP BY a") + tEnv.registerTable("t2", t2) + + val result = tEnv + .sqlQuery("SELECT a, s FROM t2 LEFT JOIN UNNEST(t2.`set`) AS A(s) ON TRUE WHERE a < 5") + .toRetractStream[Row] + result.addSink(new StreamITCase.RetractingSink).setParallelism(1) + env.execute() + + val expected = List( + "1,1", + "1,2", + "2,2", + "3,null", + "4,4" + ) + assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) + } + + @Test def testHopStartEndWithHaving(): Unit = { val env = StreamExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env)
