Repository: flink Updated Branches: refs/heads/master 0bdf3a74c -> b820fd3ca
[FLINK-5571] [table] add open and close methods for UserDefinedFunction This closes #3176. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b820fd3c Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b820fd3c Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b820fd3c Branch: refs/heads/master Commit: b820fd3ca038e411bc7f43e1c35637bf62981fe5 Parents: 0bdf3a7 Author: godfreyhe <godfre...@163.com> Authored: Fri Jan 20 14:42:12 2017 +0800 Committer: twalthr <twal...@apache.org> Committed: Mon Feb 20 15:38:34 2017 +0100 ---------------------------------------------------------------------- docs/dev/table_api.md | 92 +++++++++++++++++ .../flink/table/codegen/CodeGenerator.scala | 78 +++++++++++--- .../flink/table/functions/FunctionContext.scala | 66 ++++++++++++ .../table/functions/UserDefinedFunction.scala | 17 ++- .../table/runtime/CorrelateFlatMapRunner.scala | 7 ++ .../flink/table/runtime/FlatMapRunner.scala | 7 ++ .../table/api/scala/batch/sql/CalcITCase.scala | 6 +- .../api/scala/batch/table/CalcITCase.scala | 8 +- .../table/api/scala/stream/sql/SqlITCase.scala | 6 +- .../api/scala/stream/table/CalcITCase.scala | 8 +- .../UserDefinedScalarFunctionTest.scala | 31 +++++- .../expressions/utils/ExpressionTestBase.scala | 40 ++++++- .../utils/UserDefinedScalarFunctions.scala | 97 ++++++++++++++++- .../runtime/dataset/DataSetCalcITCase.scala | 103 +++++++++++++++++++ .../dataset/DataSetCorrelateITCase.scala | 52 +++++++++- .../datastream/DataStreamCalcITCase.scala | 81 +++++++++++++++ .../datastream/DataStreamCorrelateITCase.scala | 67 ++++++++++-- .../utils/UserDefinedFunctionTestUtils.scala | 53 ++++++++++ .../table/utils/UserDefinedTableFunctions.scala | 58 ++++++++++- 19 files changed, 828 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/docs/dev/table_api.md 
---------------------------------------------------------------------- diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md index 99ae711..22fd636 100644 --- a/docs/dev/table_api.md +++ b/docs/dev/table_api.md @@ -4724,6 +4724,11 @@ ELEMENT(ARRAY) </div> </div> +{% top %} + +User-defined Functions +---------------- + ### User-defined Scalar Functions If a required scalar function is not contained in the built-in functions, it is possible to define custom, user-defined scalar functions for both the Table API and SQL. A user-defined scalar functions maps zero, one, or multiple scalar values to a new scalar value. @@ -4933,6 +4938,93 @@ class CustomTypeSplit extends TableFunction[Row] { </div> </div> +### Advanced Function Features + +Sometimes it might be necessary for a user-defined function to get global runtime information or to do some setup work before and clean-up work after the actual work. User-defined functions provide `open()` and `close()` methods that can be overridden and that provide functionality similar to the methods of `RichFunction` in the DataSet and DataStream APIs. + +The `open()` method is called once before the evaluation method. The `close()` method is called after the last call to the evaluation method. + +The `open()` method provides a `FunctionContext` that contains information about the context in which user-defined functions are executed, such as the metric group, the distributed cache files, or the global job parameters. + +The following information can be obtained by calling the corresponding methods of `FunctionContext`: + +| Method | Description | +| :------------------------------------ | :----------------------------------------------------- | +| `getMetricGroup()` | Metric group for this parallel subtask. | +| `getCachedFile(name)` | Local temporary file copy of a distributed cache file. | +| `getJobParameter(name, defaultValue)` | Global job parameter value associated with given key.
| + +The following example snippet shows how to use `FunctionContext` in a scalar function for accessing a global job parameter: + +<div class="codetabs" markdown="1"> +<div data-lang="java" markdown="1"> +{% highlight java %} +public class HashCode extends ScalarFunction { + + private int factor = 0; + + @Override + public void open(FunctionContext context) throws Exception { + // access "hashcode_factor" parameter + // "12" would be the default value if parameter does not exist + factor = Integer.valueOf(context.getJobParameter("hashcode_factor", "12")); + } + + public int eval(String s) { + return s.hashCode() * factor; + } +} + +ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); +BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env); + +// set job parameter +Configuration conf = new Configuration(); +conf.setString("hashcode_factor", "31"); +env.getConfig().setGlobalJobParameters(conf); + +// register the function +tableEnv.registerFunction("hashCode", new HashCode()); + +// use the function in Java Table API +myTable.select("string, string.hashCode(), hashCode(string)"); + +// use the function in SQL +tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable"); +{% endhighlight %} +</div> + +<div data-lang="scala" markdown="1"> +{% highlight scala %} +object hashCode extends ScalarFunction { + + var hashcode_factor = 12; + + override def open(context: FunctionContext): Unit = { + // access "hashcode_factor" parameter + // "12" would be the default value if parameter does not exist + hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt + } + + def eval(s: String): Int = { + s.hashCode() * hashcode_factor + } +} + +val tableEnv = TableEnvironment.getTableEnvironment(env) + +// use the function in Scala Table API +myTable.select('string, hashCode('string)) + +// register and use the function in SQL +tableEnv.registerFunction("hashCode", hashCode) +tableEnv.sql("SELECT string, HASHCODE(string) FROM
MyTable"); +{% endhighlight %} + +</div> +</div> + + ### Limitations The following operations are not supported yet: http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index c679bd8..441b1c0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -25,12 +25,13 @@ import org.apache.calcite.rex._ import org.apache.calcite.sql.SqlOperator import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.fun.SqlStdOperatorTable._ -import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction} +import org.apache.flink.api.common.functions._ import org.apache.flink.api.common.io.GenericInputFormat import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation} import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, RowTypeInfo, TupleTypeInfo} import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo +import org.apache.flink.configuration.Configuration import org.apache.flink.table.api.TableConfig import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenUtils._ @@ -38,7 +39,7 @@ import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE} import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.codegen.calls.FunctionGenerator import org.apache.flink.table.codegen.calls.ScalarOperators._ -import 
org.apache.flink.table.functions.UserDefinedFunction +import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction} import org.apache.flink.table.runtime.TableFunctionCollector import org.apache.flink.table.typeutils.TypeCheckUtils._ import org.apache.flink.types.Row @@ -122,6 +123,14 @@ class CodeGenerator( // we use a LinkedHashSet to keep the insertion order private val reusableInitStatements = mutable.LinkedHashSet[String]() + // set of open statements for RichFunction that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableOpenStatements = mutable.LinkedHashSet[String]() + + // set of close statements for RichFunction that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableCloseStatements = mutable.LinkedHashSet[String]() + // set of statements that will be added only once per record // we use a LinkedHashSet to keep the insertion order private val reusablePerRecordStatements = mutable.LinkedHashSet[String]() @@ -150,6 +159,20 @@ class CodeGenerator( } /** + * @return code block of statements that need to be placed in the open() method of RichFunction + */ + def reuseOpenCode(): String = { + reusableOpenStatements.mkString("", "\n", "\n") + } + + /** + * @return code block of statements that need to be placed in the close() method of RichFunction + */ + def reuseCloseCode(): String = { + reusableCloseStatements.mkString("", "\n", "\n") + } + + /** * @return code block of statements that need to be placed in the SAM of the Function */ def reusePerRecordCode(): String = { @@ -240,27 +263,33 @@ class CodeGenerator( // manual casting here val samHeader = // FlatMapFunction - if (clazz == classOf[FlatMapFunction[_,_]]) { + if (clazz == classOf[FlatMapFunction[_, _]]) { + val baseClass = classOf[RichFlatMapFunction[_, _]] val inputTypeTerm = boxedTypeTermForTypeInfo(input1) - (s"void flatMap(Object _in1, org.apache.flink.util.Collector 
$collectorTerm)", + (baseClass, + s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)", List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) } // MapFunction - else if (clazz == classOf[MapFunction[_,_]]) { + else if (clazz == classOf[MapFunction[_, _]]) { + val baseClass = classOf[RichMapFunction[_, _]] val inputTypeTerm = boxedTypeTermForTypeInfo(input1) - ("Object map(Object _in1)", + (baseClass, + "Object map(Object _in1)", List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;")) } // FlatJoinFunction - else if (clazz == classOf[FlatJoinFunction[_,_,_]]) { + else if (clazz == classOf[FlatJoinFunction[_, _, _]]) { + val baseClass = classOf[RichFlatJoinFunction[_, _, _]] val inputTypeTerm1 = boxedTypeTermForTypeInfo(input1) val inputTypeTerm2 = boxedTypeTermForTypeInfo(input2.getOrElse( - throw new CodeGenException("Input 2 for FlatJoinFunction should not be null"))) - (s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)", + throw new CodeGenException("Input 2 for FlatJoinFunction should not be null"))) + (baseClass, + s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)", List(s"$inputTypeTerm1 $input1Term = ($inputTypeTerm1) _in1;", - s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;")) + s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;")) } else { // TODO more functions @@ -269,7 +298,7 @@ class CodeGenerator( val funcCode = j""" public class $funcName - implements ${clazz.getCanonicalName} { + extends ${samHeader._1.getCanonicalName} { ${reuseMemberCode()} @@ -280,12 +309,22 @@ class CodeGenerator( ${reuseConstructorCode(funcName)} @Override - public ${samHeader._1} throws Exception { - ${samHeader._2.mkString("\n")} + public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception { + ${reuseOpenCode()} + } + + @Override + public ${samHeader._2} throws Exception { + ${samHeader._3.mkString("\n")} 
${reusePerRecordCode()} ${reuseInputUnboxingCode()} $bodyCode } + + @Override + public void close() throws Exception { + ${reuseCloseCode()} + } } """.stripMargin @@ -1480,6 +1519,19 @@ class CodeGenerator( |$fieldTerm = ($classQualifier) $constructorTerm.newInstance(); """.stripMargin reusableInitStatements.add(constructorAccessibility) + + val openFunction = + s""" + |$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}(getRuntimeContext())); + """.stripMargin + reusableOpenStatements.add(openFunction) + + val closeFunction = + s""" + |$fieldTerm.close(); + """.stripMargin + reusableCloseStatements.add(closeFunction) + fieldTerm } http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala new file mode 100644 index 0000000..beeb686 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions + +import java.io.File + +import org.apache.flink.api.common.functions.RuntimeContext +import org.apache.flink.metrics.MetricGroup + +/** + * A FunctionContext allows obtaining global runtime information about the context in which the + * user-defined function is executed. The information includes the metric group, + * the distributed cache files, and the global job parameters. + * + * @param context the runtime context in which the Flink Function is executed + */ +class FunctionContext(context: RuntimeContext) { + + /** + * Returns the metric group for this parallel subtask. + * + * @return metric group for this parallel subtask. + */ + def getMetricGroup: MetricGroup = context.getMetricGroup + + /** + * Gets the local temporary file copy of a distributed cache file. + * + * @param name distributed cache file name + * @return local temporary file copy of a distributed cache file. + */ + def getCachedFile(name: String): File = context.getDistributedCache.getFile(name) + + /** + * Gets the global job parameter value associated with the given key as a string.
+ * + * @param key key pointing to the associated value + * @param defaultValue default value which is returned in case global job parameter is null + * or there is no value associated with the given key + * @return (default) value associated with the given key + */ + def getJobParameter(key: String, defaultValue: String): String = { + val conf = context.getExecutionConfig.getGlobalJobParameters + if (conf != null && conf.toMap.containsKey(key)) { + conf.toMap.get(key) + } else { + defaultValue + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala index b99ab8d..c313d80 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala @@ -23,5 +23,20 @@ package org.apache.flink.table.functions * * User-defined functions must have a default constructor and must be instantiable during runtime. */ -trait UserDefinedFunction { +abstract class UserDefinedFunction { + /** + * Setup method for user-defined function. It can be used for initialization work. + * + * By default, this method does nothing. + */ + @throws(classOf[Exception]) + def open(context: FunctionContext): Unit = {} + + /** + * Tear-down method for user-defined function. It can be used for clean up work. + * + * By default, this method does nothing. 
+ */ + @throws(classOf[Exception]) + def close(): Unit = {} } http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala index 4e803da..a0415e1 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime +import org.apache.flink.api.common.functions.util.FunctionUtils import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable @@ -52,6 +53,8 @@ class CorrelateFlatMapRunner[IN, OUT]( val constructor = flatMapClazz.getConstructor(classOf[TableFunctionCollector[_]]) LOG.debug("Instantiating FlatMapFunction.") function = constructor.newInstance(collector).asInstanceOf[FlatMapFunction[IN, OUT]] + FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext) + FunctionUtils.openFunction(function, parameters) } override def flatMap(in: IN, out: Collector[OUT]): Unit = { @@ -62,4 +65,8 @@ class CorrelateFlatMapRunner[IN, OUT]( } override def getProducedType: TypeInformation[OUT] = returnType + + override def close(): Unit = { + FunctionUtils.closeFunction(function) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala index a7bd980..b446306 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime +import org.apache.flink.api.common.functions.util.FunctionUtils import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.typeutils.ResultTypeQueryable @@ -43,10 +44,16 @@ class FlatMapRunner[IN, OUT]( val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code) LOG.debug("Instantiating FlatMapFunction.") function = clazz.newInstance() + FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext) + FunctionUtils.openFunction(function, parameters) } override def flatMap(in: IN, out: Collector[OUT]): Unit = function.flatMap(in, out) override def getProducedType: TypeInformation[OUT] = returnType + + override def close(): Unit = { + FunctionUtils.closeFunction(function) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala index 3710642..00f4782 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala @@ -23,11 +23,11 @@ import java.sql.{Date, Time, Timestamp} import java.util import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.scala.batch.sql.FilterITCase.MyHashCode -import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase} import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.api.scala._ -import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase} import org.apache.flink.table.api.{TableEnvironment, ValidationException} import org.apache.flink.table.functions.ScalarFunction import org.apache.flink.test.util.TestBaseUtils http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala index 2f853f3..b78dd91 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala @@ -22,15 +22,15 @@ import java.sql.{Date, Time, Timestamp} import java.util import org.apache.flink.api.scala._ -import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase} -import 
org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.types.Row import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase} import org.apache.flink.table.expressions.Literal import org.apache.flink.table.functions.ScalarFunction import org.apache.flink.test.util.TestBaseUtils +import org.apache.flink.types.Row import org.junit._ import org.junit.runner.RunWith import org.junit.runners.Parameterized http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index 97e76fa..70bec72 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -19,12 +19,12 @@ package org.apache.flink.table.api.scala.stream.sql import org.apache.flink.api.scala._ -import org.apache.flink.types.Row -import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} -import org.apache.flink.table.api.scala._ import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase import org.apache.flink.table.api.TableEnvironment +import 
org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} +import org.apache.flink.types.Row import org.junit.Assert._ import org.junit._ http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala index f541eb4..5969e91 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala @@ -19,13 +19,13 @@ package org.apache.flink.table.api.scala.stream.table import org.apache.flink.api.scala._ -import org.apache.flink.types.Row -import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.expressions.Literal import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} import org.apache.flink.table.api.{TableEnvironment, TableException} +import org.apache.flink.table.expressions.Literal +import org.apache.flink.types.Row import org.junit.Assert._ import org.junit.Test http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala index da8c748..a6c1760 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala @@ -179,7 +179,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "Func12(f8)", "+0 00:00:01.000") } - + @Test def testJavaBoxedPrimitives(): Unit = { val JavaFunc0 = new JavaFunc0() @@ -211,6 +211,30 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "null and 15 and null") } + @Test + def testRichFunctions(): Unit = { + val richFunc0 = new RichFunc0 + val richFunc1 = new RichFunc1 + val richFunc2 = new RichFunc2 + testAllApis( + richFunc0('f0), + "RichFunc0(f0)", + "RichFunc0(f0)", + "43") + + testAllApis( + richFunc1('f0), + "RichFunc1(f0)", + "RichFunc1(f0)", + "42") + + testAllApis( + richFunc2('f1), + "RichFunc2(f1)", + "RichFunc2(f1)", + "#Test") + } + // ---------------------------------------------------------------------------------------------- override def testData: Any = { @@ -256,7 +280,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "Func11" -> Func11, "Func12" -> Func12, "JavaFunc0" -> new JavaFunc0, - "JavaFunc1" -> new JavaFunc1 + "JavaFunc1" -> new JavaFunc1, + "RichFunc0" -> new RichFunc0, + "RichFunc1" -> new RichFunc1, + "RichFunc2" -> new RichFunc2 ) } http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala index 679942c..30da5ba 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala @@ -19,17 +19,23 @@ package org.apache.flink.table.expressions.utils import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder} +import java.util +import java.util.concurrent.Future import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql2rel.RelDecorrelator import org.apache.calcite.tools.{Programs, RelBuilder} -import org.apache.flink.api.common.functions.{Function, MapFunction} +import org.apache.flink.api.common.TaskInfo +import org.apache.flink.api.common.accumulators.Accumulator +import org.apache.flink.api.common.functions._ +import org.apache.flink.api.common.functions.util.RuntimeUDFContext import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.api.java.{DataSet => JDataSet} import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment} -import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.types.Row +import org.apache.flink.configuration.Configuration +import org.apache.flink.core.fs.Path import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableEnvironment} import org.apache.flink.table.calcite.FlinkPlannerImpl import org.apache.flink.table.codegen.{CodeGenerator, Compiler, GeneratedFunction} @@ -37,6 +43,7 @@ import org.apache.flink.table.expressions.{Expression, ExpressionParser} import 
org.apache.flink.table.functions.ScalarFunction import org.apache.flink.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention} import org.apache.flink.table.plan.rules.FlinkRuleSets +import org.apache.flink.types.Row import org.junit.Assert._ import org.junit.{After, Before} import org.mockito.Mockito._ @@ -69,7 +76,8 @@ abstract class ExpressionTestBase { new HepPlanner(builder.build, context._2.getFrameworkConfig.getContext) } - private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = { + private def prepareContext(typeInfo: TypeInformation[Any]) + : (RelBuilder, TableEnvironment, ExecutionEnvironment) = { // create DataSetTable val dataSetMock = mock(classOf[DataSet[Any]]) val jDataSetMock = mock(classOf[JDataSet[Any]]) @@ -85,7 +93,7 @@ abstract class ExpressionTestBase { val relBuilder = tEnv.getRelBuilder relBuilder.scan(tableName) - (relBuilder, tEnv) + (relBuilder, tEnv, env) } def testData: Any @@ -130,8 +138,30 @@ abstract class ExpressionTestBase { // compile and evaluate val clazz = new TestCompiler[MapFunction[Any, Row], Row]().compile(genFunc) val mapper = clazz.newInstance() + + val isRichFunction = mapper.isInstanceOf[RichFunction] + + // call setRuntimeContext method and open method for RichFunction + if (isRichFunction) { + val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]] + val t = new RuntimeUDFContext( + new TaskInfo("ExpressionTest", 1, 0, 1, 1), + null, + context._3.getConfig, + new util.HashMap[String, Future[Path]](), + new util.HashMap[String, Accumulator[_, _]](), + null) + richMapper.setRuntimeContext(t) + richMapper.open(new Configuration()) + } + val result = mapper.map(testData) + // call close method for RichFunction + if (isRichFunction) { + mapper.asInstanceOf[RichMapFunction[_, _]].close() + } + // compare testExprs .zipWithIndex 
http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala index 4e9b6d3..f0b347d 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala @@ -22,7 +22,11 @@ import java.sql.{Date, Time, Timestamp} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api.Types -import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.functions.{ScalarFunction, FunctionContext} +import org.junit.Assert + +import scala.collection.mutable +import scala.io.Source case class SimplePojo(name: String, age: Int) @@ -119,3 +123,94 @@ object Func12 extends ScalarFunction { Types.INTERVAL_MILLIS } } + +class RichFunc0 extends ScalarFunction { + var openCalled = false + var closeCalled = false + + override def open(context: FunctionContext): Unit = { + super.open(context) + if (openCalled) { + Assert.fail("Open called more than once.") + } else { + openCalled = true + } + if (closeCalled) { + Assert.fail("Close called before open.") + } + } + + def eval(index: Int): Int = { + if (!openCalled) { + Assert.fail("Open was not called before eval.") + } + if (closeCalled) { + Assert.fail("Close called before eval.") + } + + index + 1 + } + + override def close(): Unit = { + super.close() + if (closeCalled) { + Assert.fail("Close called more than once.") + } else { + closeCalled = true + } + if (!openCalled) { + 
Assert.fail("Open was not called before close.") + } + } +} + +class RichFunc1 extends ScalarFunction { + var added = Int.MaxValue + + override def open(context: FunctionContext): Unit = { + added = context.getJobParameter("int.value", "0").toInt + } + + def eval(index: Int): Int = { + index + added + } + + override def close(): Unit = { + added = Int.MaxValue + } +} + +class RichFunc2 extends ScalarFunction { + var prefix = "ERROR_VALUE" + + override def open(context: FunctionContext): Unit = { + prefix = context.getJobParameter("string.value", "") + } + + def eval(value: String): String = { + prefix + "#" + value + } + + override def close(): Unit = { + prefix = "ERROR_VALUE" + } +} + +class RichFunc3 extends ScalarFunction { + private val words = mutable.HashSet[String]() + + override def open(context: FunctionContext): Unit = { + val file = context.getCachedFile("words") + for (line <- Source.fromFile(file.getCanonicalPath).getLines) { + words.add(line.trim) + } + } + + def eval(value: String): Boolean = { + words.contains(value) + } + + override def close(): Unit = { + words.clear() + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala new file mode 100644 index 0000000..f0b3b44 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.dataset + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.expressions.utils.{RichFunc1, RichFunc2, RichFunc3} +import org.apache.flink.table.utils._ +import org.apache.flink.test.util.TestBaseUtils +import org.apache.flink.types.Row +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ + +@RunWith(classOf[Parameterized]) +class DataSetCalcITCase( + configMode: TableConfigMode) + extends TableProgramsClusterTestBase(configMode) { + + @Test + def testUserDefinedScalarFunctionWithParameter(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerFunction("RichFunc2", new RichFunc2) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("string.value" -> "ABC")) + + val ds = CollectionDataSets.getSmall3TupleDataSet(env) + tEnv.registerDataSet("t1", ds, 'a, 'b, 'c) + + val sqlQuery = "SELECT c FROM t1 where RichFunc2(c)='ABC#Hello'" + + val 
result = tEnv.sql(sqlQuery) + + val expected = "Hello" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testUserDefinedScalarFunctionWithDistributedCache(): Unit = { + val words = "Hello\nWord" + val filePath = UserDefinedFunctionTestUtils.writeCacheFile("test_words", words) + val env = ExecutionEnvironment.getExecutionEnvironment + env.registerCachedFile(filePath, "words") + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerFunction("RichFunc3", new RichFunc3) + + val ds = CollectionDataSets.getSmall3TupleDataSet(env) + tEnv.registerDataSet("t1", ds, 'a, 'b, 'c) + + val sqlQuery = "SELECT c FROM t1 where RichFunc3(c)=true" + + val result = tEnv.sql(sqlQuery) + + val expected = "Hello" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testMultipleUserDefinedScalarFunctions(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerFunction("RichFunc1", new RichFunc1) + tEnv.registerFunction("RichFunc2", new RichFunc2) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("string.value" -> "Abc")) + + val ds = CollectionDataSets.getSmall3TupleDataSet(env) + tEnv.registerDataSet("t1", ds, 'a, 'b, 'c) + + val sqlQuery = "SELECT c FROM t1 where " + + "RichFunc2(c)='Abc#Hello' or RichFunc1(a)=3 and b=2" + + val result = tEnv.sql(sqlQuery) + + val expected = "Hello\nHello world" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala index 818f52b..cd1ffb5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala @@ -20,14 +20,16 @@ package org.apache.flink.table.runtime.dataset import java.sql.{Date, Timestamp} import org.apache.flink.api.scala._ -import org.apache.flink.types.Row -import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase -import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.table.api.TableEnvironment import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0 +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase +import org.apache.flink.table.expressions.utils.RichFunc2 import org.apache.flink.table.utils._ import org.apache.flink.test.util.TestBaseUtils +import org.apache.flink.types.Row import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized @@ -147,7 +149,7 @@ class DataSetCorrelateITCase( } @Test - def testUDTFWithScalarFunction(): Unit = { + def testUserDefinedTableFunctionWithScalarFunction(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tableEnv = TableEnvironment.getTableEnvironment(env, config) val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) @@ -185,6 +187,46 @@ class DataSetCorrelateITCase( 
TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test + def testUserDefinedTableFunctionWithParameter(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val richTableFunc1 = new RichTableFunc1 + tEnv.registerFunction("RichTableFunc1", richTableFunc1) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#")) + + val result = testData(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(richTableFunc1('c) as 's) + .select('a, 's) + + val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val richTableFunc1 = new RichTableFunc1 + tEnv.registerFunction("RichTableFunc1", richTableFunc1) + val richFunc2 = new RichFunc2 + tEnv.registerFunction("RichFunc2", richFunc2) + UserDefinedFunctionTestUtils.setJobParameters( + env, + Map("word_separator" -> "#", "string.value" -> "test")) + + val result = CollectionDataSets.getSmall3TupleDataSet(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(richTableFunc1(richFunc2('c)) as 's) + .select('a, 's) + + val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + private def testData( env: ExecutionEnvironment) : DataSet[(Int, Long, String)] = { http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala new file mode 100644 index 0000000..1d48f2c --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.datastream + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} +import org.apache.flink.table.expressions.utils.{RichFunc1, RichFunc2} +import org.apache.flink.table.utils.UserDefinedFunctionTestUtils +import org.apache.flink.types.Row +import org.junit.Assert._ +import org.junit.Test + +import scala.collection.mutable + +class DataStreamCalcITCase extends StreamingMultipleProgramsTestBase { + + @Test + def testUserDefinedFunctionWithParameter(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerFunction("RichFunc2", new RichFunc2) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("string.value" -> "ABC")) + + StreamITCase.testResults = mutable.MutableList() + + val result = StreamTestData.get3TupleDataStream(env) + .toTable(tEnv, 'a, 'b, 'c) + .where("RichFunc2(c)='ABC#Hello'") + .select('c) + + val results = result.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Hello") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testMultipleUserDefinedFunctions(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + tEnv.registerFunction("RichFunc1", new RichFunc1) + tEnv.registerFunction("RichFunc2", new RichFunc2) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("string.value" -> "Abc")) + + StreamITCase.testResults = mutable.MutableList() + + val result = StreamTestData.get3TupleDataStream(env) + .toTable(tEnv, 'a, 'b, 'c) + 
.where("RichFunc2(c)='Abc#Hello' || RichFunc1(a)=3 && b=2") + .select('c) + + val results = result.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Hello", "Hello world") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala index eb20517..f8a697d 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala @@ -18,13 +18,14 @@ package org.apache.flink.table.runtime.datastream import org.apache.flink.api.scala._ -import org.apache.flink.types.Row -import org.apache.flink.table.api.scala.stream.utils.StreamITCase -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.utils.TableFunc0 import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} +import org.apache.flink.table.expressions.utils.RichFunc2 +import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, UserDefinedFunctionTestUtils} +import org.apache.flink.types.Row import org.junit.Assert._ import org.junit.Test @@ -76,9 +77,63 @@ class 
DataStreamCorrelateITCase extends StreamingMultipleProgramsTestBase { assertEquals(expected.sorted, StreamITCase.testResults.sorted) } + @Test + def testUserDefinedTableFunctionWithParameter(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val tableFunc1 = new RichTableFunc1 + tEnv.registerFunction("RichTableFunc1", tableFunc1) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> " ")) + StreamITCase.testResults = mutable.MutableList() + + val result = StreamTestData.getSmall3TupleDataStream(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(tableFunc1('c) as 's) + .select('a, 's) + + val results = result.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("3,Hello", "3,world") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val tableFunc1 = new RichTableFunc1 + val richFunc2 = new RichFunc2 + tEnv.registerFunction("RichTableFunc1", tableFunc1) + tEnv.registerFunction("RichFunc2", richFunc2) + UserDefinedFunctionTestUtils.setJobParameters( + env, + Map("word_separator" -> "#", "string.value" -> "test")) + StreamITCase.testResults = mutable.MutableList() + + val result = StreamTestData.getSmall3TupleDataStream(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(tableFunc1(richFunc2('c)) as 's) + .select('a, 's) + + val results = result.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,Hi", + "1,test", + "2,Hello", + "2,test", + "3,Hello world", + "3,test") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + private def testData( - env: StreamExecutionEnvironment) - : DataStream[(Int, Long, 
String)] = { + env: StreamExecutionEnvironment) + : DataStream[(Int, Long, String)] = { val data = new mutable.MutableList[(Int, Long, String)] data.+=((1, 1L, "Jack#22")) http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala new file mode 100644 index 0000000..deaedc9 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.utils + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment + +object UserDefinedFunctionTestUtils { + + def setJobParameters(env: ExecutionEnvironment, parameters: Map[String, String]): Unit = { + val conf = new Configuration() + parameters.foreach { + case (k, v) => conf.setString(k, v) + } + env.getConfig.setGlobalJobParameters(conf) + } + + def setJobParameters(env: StreamExecutionEnvironment, parameters: Map[String, String]): Unit = { + val conf = new Configuration() + parameters.foreach { + case (k, v) => conf.setString(k, v) + } + env.getConfig.setGlobalJobParameters(conf) + } + + def writeCacheFile(fileName: String, contents: String): String = { + val tempFile = File.createTempFile(this.getClass.getName + "-" + fileName, "tmp") + tempFile.deleteOnExit() + Files.write(contents, tempFile, Charsets.UTF_8) + tempFile.getAbsolutePath + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala index 54861ea..5db9d5f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala @@ -21,9 +21,11 @@ import java.lang.Boolean import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.Tuple3 
-import org.apache.flink.types.Row import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.table.functions.TableFunction +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.functions.{TableFunction, FunctionContext} +import org.apache.flink.types.Row +import org.junit.Assert case class SimpleUser(name: String, age: Int) @@ -115,3 +117,55 @@ object ObjectTableFunction extends TableFunction[Integer] { collect(b) } } + +class RichTableFunc0 extends TableFunction[String] { + var openCalled = false + var closeCalled = false + + override def open(context: FunctionContext): Unit = { + super.open(context) + if (closeCalled) { + Assert.fail("Close called before open.") + } + openCalled = true + } + + def eval(str: String): Unit = { + if (!openCalled) { + Assert.fail("Open was not called before eval.") + } + if (closeCalled) { + Assert.fail("Close called before eval.") + } + + if (!str.contains("#")) { + collect(str) + } + } + + override def close(): Unit = { + super.close() + if (!openCalled) { + Assert.fail("Open was not called before close.") + } + closeCalled = true + } +} + +class RichTableFunc1 extends TableFunction[String] { + var separator: Option[String] = None + + override def open(context: FunctionContext): Unit = { + separator = Some(context.getJobParameter("word_separator", "")) + } + + def eval(str: String): Unit = { + if (str.contains(separator.getOrElse(throw new ValidationException(s"no separator")))) { + str.split(separator.get).foreach(collect) + } + } + + override def close(): Unit = { + separator = None + } +}