Repository: flink Updated Branches: refs/heads/master 11c868f91 -> 45e01cf23
[FLINK-5795] [table] Improve UDF&UDTF to support constructor with parameter this closes #3330 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/45e01cf2 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/45e01cf2 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/45e01cf2 Branch: refs/heads/master Commit: 45e01cf2321dda58f572d8b9dbe64947c6725ad1 Parents: 11c868f Author: é竹 <[email protected]> Authored: Tue Feb 14 14:43:41 2017 +0800 Committer: Jark Wu <[email protected]> Committed: Wed Feb 22 10:02:04 2017 +0800 ---------------------------------------------------------------------- .../flink/table/codegen/CodeGenerator.scala | 19 +- .../apache/flink/table/expressions/call.scala | 6 +- .../table/functions/UserDefinedFunction.scala | 11 +- .../utils/UserDefinedFunctionUtils.scala | 33 +-- .../flink/table/plan/logical/operators.scala | 2 +- .../flink/table/validate/FunctionCatalog.scala | 12 +- .../flink/table/CompositeFlatteningTest.scala | 8 +- .../scala/batch/table/FieldProjectionTest.scala | 4 +- .../table/UserDefinedTableFunctionTest.scala | 6 +- .../table/UserDefinedTableFunctionTest.scala | 16 +- .../utils/UserDefinedScalarFunctions.scala | 7 + .../dataset/DataSetCorrelateITCase.scala | 241 ---------------- .../DataSetUserDefinedFunctionITCase.scala | 288 +++++++++++++++++++ .../DataSetUserDefinedFunctionITCase.scala | 206 +++++++++++++ .../datastream/DataStreamCorrelateITCase.scala | 146 ---------- .../table/utils/UserDefinedTableFunctions.scala | 36 ++- 16 files changed, 595 insertions(+), 446 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 441b1c0..6658645 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -40,6 +40,7 @@ import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.codegen.calls.FunctionGenerator import org.apache.flink.table.codegen.calls.ScalarOperators._ import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils import org.apache.flink.table.runtime.TableFunctionCollector import org.apache.flink.table.typeutils.TypeCheckUtils._ import org.apache.flink.types.Row @@ -1494,15 +1495,14 @@ class CodeGenerator( /** * Adds a reusable [[UserDefinedFunction]] to the member area of the generated [[Function]]. - * The [[UserDefinedFunction]] must have a default constructor, however, it does not have - * to be public. 
* * @param function [[UserDefinedFunction]] object to be instantiated during runtime * @return member variable term */ def addReusableFunction(function: UserDefinedFunction): String = { val classQualifier = function.getClass.getCanonicalName - val fieldTerm = s"function_${classQualifier.replace('.', '$')}" + val functionSerializedData = UserDefinedFunctionUtils.serialize(function) + val fieldTerm = s"function_${function.functionIdentifier}" val fieldFunction = s""" @@ -1510,15 +1510,14 @@ class CodeGenerator( |""".stripMargin reusableMemberStatements.add(fieldFunction) - val constructorTerm = s"constructor_${classQualifier.replace('.', '$')}" - val constructorAccessibility = + val functionDeserialization = s""" - |java.lang.reflect.Constructor $constructorTerm = - | $classQualifier.class.getDeclaredConstructor(); - |$constructorTerm.setAccessible(true); - |$fieldTerm = ($classQualifier) $constructorTerm.newInstance(); + |$fieldTerm = ($classQualifier) + |${UserDefinedFunctionUtils.getClass.getName.stripSuffix("$")} + |.deserialize("$functionSerializedData"); """.stripMargin - reusableInitStatements.add(constructorAccessibility) + + reusableInitStatements.add(functionDeserialization) val openFunction = s""" http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala index ef2cf4e..40db13e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/call.scala @@ -20,12 +20,12 @@ package org.apache.flink.table.expressions import org.apache.calcite.rex.RexNode import 
org.apache.calcite.tools.RelBuilder import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.api.{UnresolvedException, ValidationException} import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.functions.{ScalarFunction, TableFunction} import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ +import org.apache.flink.table.functions.{ScalarFunction, TableFunction} import org.apache.flink.table.plan.logical.{LogicalNode, LogicalTableFunctionCall} import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess} -import org.apache.flink.table.api.{UnresolvedException, ValidationException} /** * General expression for unresolved function calls. The function can be a built-in @@ -67,7 +67,7 @@ case class ScalarFunctionCall( val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] relBuilder.call( createScalarSqlFunction( - scalarFunction.getClass.getCanonicalName, + scalarFunction.functionIdentifier, scalarFunction, typeFactory), parameters.map(_.toRexNode): _*) http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala index c313d80..e9e01ee 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala @@ -17,13 +17,13 @@ */ package org.apache.flink.table.functions +import org.apache.commons.codec.digest.DigestUtils +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.serialize 
/** * Base class for all user-defined functions such as scalar functions, table functions, * or aggregation functions. - * - * User-defined functions must have a default constructor and must be instantiable during runtime. */ -abstract class UserDefinedFunction { +abstract class UserDefinedFunction extends Serializable { /** * Setup method for user-defined function. It can be used for initialization work. * @@ -39,4 +39,9 @@ abstract class UserDefinedFunction { */ @throws(classOf[Exception]) def close(): Unit = {} + + final def functionIdentifier: String = { + val md5 = DigestUtils.md5Hex(serialize(this)) + getClass.getCanonicalName.replace('.', '$').concat("$").concat(md5) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala index f324dc1..16a6717b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala @@ -23,6 +23,7 @@ import java.lang.{Long => JLong, Integer => JInt} import java.lang.reflect.{Method, Modifier} import java.sql.{Date, Time, Timestamp} +import org.apache.commons.codec.binary.Base64 import com.google.common.primitives.Primitives import org.apache.calcite.sql.SqlFunction import org.apache.flink.api.common.functions.InvalidTypesException @@ -37,15 +38,6 @@ import org.apache.flink.util.InstantiationUtil object UserDefinedFunctionUtils { /** - * Instantiates a user-defined function. 
- */ - def instantiate[T <: UserDefinedFunction](clazz: Class[T]): T = { - val constructor = clazz.getDeclaredConstructor() - constructor.setAccessible(true) - constructor.newInstance() - } - - /** * Checks if a user-defined function can be easily instantiated. */ def checkForInstantiation(clazz: Class[_]): Unit = { @@ -59,12 +51,6 @@ object UserDefinedFunctionUtils { else if (InstantiationUtil.isNonStaticInnerClass(clazz)) { throw ValidationException("The class is an inner class, but not statically accessible.") } - - // check for default constructor (can be private) - clazz - .getDeclaredConstructors - .find(_.getParameterTypes.isEmpty) - .getOrElse(throw ValidationException("Function class needs a default constructor.")) } /** @@ -168,7 +154,7 @@ object UserDefinedFunctionUtils { /** * Create [[SqlFunction]] for a [[ScalarFunction]] - * + * * @param name function name * @param function scalar function * @param typeFactory type factory @@ -184,7 +170,7 @@ object UserDefinedFunctionUtils { /** * Create [[SqlFunction]]s for a [[TableFunction]]'s every eval method - * + * * @param name function name * @param tableFunction table function * @param resultType the type information of returned table @@ -311,7 +297,6 @@ object UserDefinedFunctionUtils { } }.toArray - /** * Compares parameter candidate classes with expected classes. If true, the parameters match. * Candidate can be null (acts as a wildcard). 
@@ -324,4 +309,16 @@ object UserDefinedFunctionUtils { candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) || candidate == classOf[Timestamp] && (expected == classOf[Long] || expected == classOf[JLong]) + @throws[Exception] + def serialize(function: UserDefinedFunction): String = { + val byteArray = InstantiationUtil.serializeObject(function) + Base64.encodeBase64URLSafeString(byteArray) + } + + @throws[Exception] + def deserialize(data: String): UserDefinedFunction = { + val byteData = Base64.decodeBase64(data) + InstantiationUtil + .deserializeObject[UserDefinedFunction](byteData, Thread.currentThread.getContextClassLoader) + } } http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala index 20f810a..1b5eafb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala @@ -694,7 +694,7 @@ case class LogicalTableFunctionCall( val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, evalMethod) val typeFactory = relBuilder.getTypeFactory.asInstanceOf[FlinkTypeFactory] val sqlFunction = TableSqlFunction( - tableFunction.toString, + tableFunction.functionIdentifier, tableFunction, resultType, typeFactory, http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala index 207eba1..94237f7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/validate/FunctionCatalog.scala @@ -23,8 +23,8 @@ import org.apache.calcite.sql.util.{ChainedSqlOperatorTable, ListSqlOperatorTabl import org.apache.calcite.sql.{SqlFunction, SqlOperator, SqlOperatorTable} import org.apache.flink.table.api.ValidationException import org.apache.flink.table.expressions._ -import org.apache.flink.table.functions.utils.{TableSqlFunction, UserDefinedFunctionUtils} import org.apache.flink.table.functions.{EventTimeExtractor, RowTime, ScalarFunction, TableFunction} +import org.apache.flink.table.functions.utils.{TableSqlFunction, ScalarSqlFunction} import scala.collection.JavaConversions._ import scala.collection.mutable @@ -81,11 +81,11 @@ class FunctionCatalog { // user-defined scalar function call case sf if classOf[ScalarFunction].isAssignableFrom(sf) => - Try(UserDefinedFunctionUtils.instantiate(sf.asInstanceOf[Class[ScalarFunction]])) match { - case Success(scalarFunction) => ScalarFunctionCall(scalarFunction, children) - case Failure(e) => throw ValidationException(e.getMessage) - } - + val scalarSqlFunction = sqlFunctions + .find(f => f.getName.equalsIgnoreCase(name) && f.isInstanceOf[ScalarSqlFunction]) + .getOrElse(throw ValidationException(s"Undefined scalar function: $name")) + .asInstanceOf[ScalarSqlFunction] + ScalarFunctionCall(scalarSqlFunction.getScalarFunction, children) // user-defined table function call case tf if classOf[TableFunction[_]].isAssignableFrom(tf) => val tableSqlFunction = sqlFunctions 
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala index 0055fc2..f5f5ff1 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/CompositeFlatteningTest.scala @@ -119,10 +119,10 @@ class CompositeFlatteningTest extends TableTestBase { "DataSetCalc", batchTableNode(0), term("select", - "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().my AS _c0", - "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().clazz AS _c1", - "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().my AS _c2", - "org.apache.flink.table.CompositeFlatteningTest.giveMeCaseClass$().clazz AS _c3" + s"${giveMeCaseClass.functionIdentifier}().my AS _c0", + s"${giveMeCaseClass.functionIdentifier}().clazz AS _c1", + s"${giveMeCaseClass.functionIdentifier}().my AS _c2", + s"${giveMeCaseClass.functionIdentifier}().clazz AS _c3" ) ) http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala index d053b9f..0066ad2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala +++ 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala @@ -103,7 +103,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( "DataSetCalc", batchTableNode(0), - term("select", s"${MyHashCode.getClass.getCanonicalName}(c) AS _c0", "b") + term("select", s"${MyHashCode.functionIdentifier}(c) AS _c0", "b") ) util.verifyTable(resultTable, expected) @@ -212,7 +212,7 @@ class FieldProjectionTest extends TableTestBase { unaryNode( "DataSetCalc", batchTableNode(0), - term("select", "a", "c", s"${MyHashCode.getClass.getCanonicalName}(c) AS k") + term("select", "a", "c", s"${MyHashCode.functionIdentifier}(c) AS k") ), term("groupBy", "k"), term("select", "k", "SUM(a) AS TMP_0") http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala index f8d9c92..2dbcccf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/UserDefinedTableFunctionTest.scala @@ -120,7 +120,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataSetCorrelate", batchTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), @@ -140,7 +140,7 @@ class UserDefinedTableFunctionTest extends TableTestBase 
{ unaryNode( "DataSetCorrelate", batchTableNode(0), - term("invocation", s"$function($$2, '$$')"), + term("invocation", s"${function.functionIdentifier}($$2, '$$')"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), @@ -165,7 +165,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataSetCorrelate", batchTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala index 168f9ec..56b9fdb 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/UserDefinedTableFunctionTest.scala @@ -183,7 +183,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), @@ -203,7 +203,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2, '$$')"), 
+ term("invocation", s"${function.functionIdentifier}($$2, '$$')"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), @@ -228,7 +228,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), @@ -253,7 +253,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + @@ -277,7 +277,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," + @@ -299,7 +299,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + @@ -326,7 +326,7 @@ class UserDefinedTableFunctionTest extends TableTestBase { unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function($$2)"), + term("invocation", s"${function.functionIdentifier}($$2)"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " + @@ -351,7 +351,7 @@ class UserDefinedTableFunctionTest 
extends TableTestBase { val expected = unaryNode( "DataStreamCorrelate", streamTableNode(0), - term("invocation", s"$function(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"), + term("invocation", s"${function.functionIdentifier}(SUBSTRING($$2, 2, CHAR_LENGTH($$2)))"), term("function", function), term("rowType", "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647) s)"), http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala index f0b347d..4fee3b2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala @@ -214,3 +214,10 @@ class RichFunc3 extends ScalarFunction { words.clear() } } + +class Func13(prefix: String) extends ScalarFunction { + def eval(a: String): String = { + s"$prefix-$a" + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala deleted file mode 100644 index cd1ffb5..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala 
+++ /dev/null @@ -1,241 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.runtime.dataset - -import java.sql.{Date, Timestamp} - -import org.apache.flink.api.scala._ -import org.apache.flink.api.scala.util.CollectionDataSets -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0 -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase -import org.apache.flink.table.expressions.utils.RichFunc2 -import org.apache.flink.table.utils._ -import org.apache.flink.test.util.TestBaseUtils -import org.apache.flink.types.Row -import org.junit.Test -import org.junit.runner.RunWith -import org.junit.runners.Parameterized - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -@RunWith(classOf[Parameterized]) -class DataSetCorrelateITCase( - configMode: TableConfigMode) - extends TableProgramsClusterTestBase(configMode) { - - @Test - def testCrossJoin(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val 
tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - - val func1 = new TableFunc1 - val result = in.join(func1('c) as 's).select('c, 's).toDataSet[Row] - val results = result.collect() - val expected = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" + - "Anna#44,Anna\n" + "Anna#44,44\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - - // with overloading - val result2 = in.join(func1('c, "$") as 's).select('c, 's).toDataSet[Row] - val results2 = result2.collect() - val expected2 = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" + - "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n" - TestBaseUtils.compareResultAsText(results2.asJava, expected2) - } - - @Test - def testLeftOuterJoin(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - - val func2 = new TableFunc2 - val result = in.leftOuterJoin(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row] - val results = result.collect() - val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" + - "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testWithFilter(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - val func0 = new TableFunc0 - - val result = in - .join(func0('c) as ('name, 'age)) - .select('c, 'name, 'age) - .filter('age > 20) - .toDataSet[Row] - - val results = result.collect() - val expected = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testCustomReturnType(): Unit = { - val env = 
ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - val func2 = new TableFunc2 - - val result = in - .join(func2('c) as ('name, 'len)) - .select('c, 'name, 'len) - .toDataSet[Row] - - val results = result.collect() - val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" + - "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testHierarchyType(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - - val hierarchy = new HierarchyTableFunction - val result = in - .join(hierarchy('c) as ('name, 'adult, 'len)) - .select('c, 'name, 'adult, 'len) - .toDataSet[Row] - - val results = result.collect() - val expected = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" + - "Anna#44,Anna,true,44\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testPojoType(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - - val pojo = new PojoTableFunc() - val result = in - .join(pojo('c)) - .select('c, 'name, 'age) - .toDataSet[Row] - - val results = result.collect() - val expected = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testUserDefinedTableFunctionWithScalarFunction(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - val func1 = new TableFunc1 - - val result = in - .join(func1('c.substring(2)) as 's) - 
.select('c, 's) - .toDataSet[Row] - - val results = result.collect() - val expected = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" + - "Anna#44,nna\n" + "Anna#44,44\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testLongAndTemporalTypes(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tableEnv = TableEnvironment.getTableEnvironment(env, config) - val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) - val func0 = new JavaTableFunc0 - - val result = in - .where('a === 1) - .select(Date.valueOf("1990-10-14") as 'x, - 1000L as 'y, - Timestamp.valueOf("1990-10-14 12:10:10") as 'z) - .join(func0('x, 'y, 'z) as 's) - .select('s) - .toDataSet[Row] - - val results = result.collect() - val expected = "1000\n" + "655906210000\n" + "7591\n" - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testUserDefinedTableFunctionWithParameter(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - val richTableFunc1 = new RichTableFunc1 - tEnv.registerFunction("RichTableFunc1", richTableFunc1) - UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#")) - - val result = testData(env) - .toTable(tEnv, 'a, 'b, 'c) - .join(richTableFunc1('c) as 's) - .select('a, 's) - - val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44" - val results = result.toDataSet[Row].collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - @Test - def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - val richTableFunc1 = new RichTableFunc1 - tEnv.registerFunction("RichTableFunc1", richTableFunc1) - val richFunc2 = new RichFunc2 - tEnv.registerFunction("RichFunc2", richFunc2) - 
UserDefinedFunctionTestUtils.setJobParameters( - env, - Map("word_separator" -> "#", "string.value" -> "test")) - - val result = CollectionDataSets.getSmall3TupleDataSet(env) - .toTable(tEnv, 'a, 'b, 'c) - .join(richTableFunc1(richFunc2('c)) as 's) - .select('a, 's) - - val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test" - val results = result.toDataSet[Row].collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) - } - - private def testData( - env: ExecutionEnvironment) - : DataSet[(Int, Long, String)] = { - - val data = new mutable.MutableList[(Int, Long, String)] - data.+=((1, 1L, "Jack#22")) - data.+=((2, 2L, "John#19")) - data.+=((3, 2L, "Anna#44")) - data.+=((4, 3L, "nosharp")) - env.fromCollection(data) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala new file mode 100644 index 0000000..d268594 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime.dataset + +import java.sql.{Date, Timestamp} + +import org.apache.flink.api.scala._ +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0 +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase +import org.apache.flink.table.expressions.utils.{Func13, RichFunc2} +import org.apache.flink.table.utils._ +import org.apache.flink.test.util.TestBaseUtils +import org.apache.flink.types.Row +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +@RunWith(classOf[Parameterized]) +class DataSetUserDefinedFunctionITCase( + configMode: TableConfigMode) + extends TableProgramsClusterTestBase(configMode) { + + @Test + def testCrossJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + + val func1 = new TableFunc1 + val result = in.join(func1('c) as 's).select('c, 's).toDataSet[Row] + val results = result.collect() + val expected = "Jack#22,Jack\n" + "Jack#22,22\n" + "John#19,John\n" + "John#19,19\n" + + "Anna#44,Anna\n" + "Anna#44,44\n" + 
TestBaseUtils.compareResultAsText(results.asJava, expected) + + // with overloading + val result2 = in.join(func1('c, "$") as 's).select('c, 's).toDataSet[Row] + val results2 = result2.collect() + val expected2 = "Jack#22,$Jack\n" + "Jack#22,$22\n" + "John#19,$John\n" + + "John#19,$19\n" + "Anna#44,$Anna\n" + "Anna#44,$44\n" + TestBaseUtils.compareResultAsText(results2.asJava, expected2) + } + + @Test + def testLeftOuterJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + + val func2 = new TableFunc2 + val result = in.leftOuterJoin(func2('c) as ('s, 'l)).select('c, 's, 'l).toDataSet[Row] + val results = result.collect() + val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + "John#19,John,4\n" + + "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + "nosharp,null,null" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testWithFilter(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = in + .join(func0('c) as ('name, 'age)) + .select('c, 'name, 'age) + .filter('age > 20) + .toDataSet[Row] + + val results = result.collect() + val expected = "Jack#22,Jack,22\n" + "Anna#44,Anna,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testCustomReturnType(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val func2 = new TableFunc2 + + val result = in + .join(func2('c) as ('name, 'len)) + .select('c, 'name, 'len) + .toDataSet[Row] + + val results = result.collect() + val expected = "Jack#22,Jack,4\n" + "Jack#22,22,2\n" + 
"John#19,John,4\n" + + "John#19,19,2\n" + "Anna#44,Anna,4\n" + "Anna#44,44,2\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testHierarchyType(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + + val hierarchy = new HierarchyTableFunction + val result = in + .join(hierarchy('c) as ('name, 'adult, 'len)) + .select('c, 'name, 'adult, 'len) + .toDataSet[Row] + + val results = result.collect() + val expected = "Jack#22,Jack,true,22\n" + "John#19,John,false,19\n" + + "Anna#44,Anna,true,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testPojoType(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + + val pojo = new PojoTableFunc() + val result = in + .join(pojo('c)) + .select('c, 'name, 'age) + .toDataSet[Row] + + val results = result.collect() + val expected = "Jack#22,Jack,22\n" + "John#19,John,19\n" + "Anna#44,Anna,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testUserDefinedTableFunctionWithScalarFunction(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val func1 = new TableFunc1 + + val result = in + .join(func1('c.substring(2)) as 's) + .select('c, 's) + .toDataSet[Row] + + val results = result.collect() + val expected = "Jack#22,ack\n" + "Jack#22,22\n" + "John#19,ohn\n" + "John#19,19\n" + + "Anna#44,nna\n" + "Anna#44,44\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testLongAndTemporalTypes(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = 
TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val func0 = new JavaTableFunc0 + + val result = in + .where('a === 1) + .select(Date.valueOf("1990-10-14") as 'x, + 1000L as 'y, + Timestamp.valueOf("1990-10-14 12:10:10") as 'z) + .join(func0('x, 'y, 'z) as 's) + .select('s) + .toDataSet[Row] + + val results = result.collect() + val expected = "1000\n" + "655906210000\n" + "7591\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testUserDefinedTableFunctionWithParameter(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val richTableFunc1 = new RichTableFunc1 + tEnv.registerFunction("RichTableFunc1", richTableFunc1) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#")) + + val result = testData(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(richTableFunc1('c) as 's) + .select('a, 's) + + val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44" + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val richTableFunc1 = new RichTableFunc1 + tEnv.registerFunction("RichTableFunc1", richTableFunc1) + val richFunc2 = new RichFunc2 + tEnv.registerFunction("RichFunc2", richFunc2) + UserDefinedFunctionTestUtils.setJobParameters( + env, + Map("word_separator" -> "#", "string.value" -> "test")) + + val result = CollectionDataSets.getSmall3TupleDataSet(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(richTableFunc1(richFunc2('c)) as 's) + .select('a, 's) + + val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test" + val results = result.toDataSet[Row].collect() + 
TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testTableFunctionConstructorWithParams(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val func30 = new TableFunc3(null) + val func31 = new TableFunc3("OneConf_") + val func32 = new TableFunc3("TwoConf_") + + val result = in + .join(func30('c) as('d, 'e)) + .select('c, 'd, 'e) + .join(func31('c) as ('f, 'g)) + .select('c, 'd, 'e, 'f, 'g) + .join(func32('c) as ('h, 'i)) + .select('c, 'd, 'f, 'h, 'e, 'g, 'i) + .toDataSet[Row] + + val results = result.collect() + + val expected = "Anna#44,Anna,OneConf_Anna,TwoConf_Anna,44,44,44\n" + + "Jack#22,Jack,OneConf_Jack,TwoConf_Jack,22,22,22\n" + + "John#19,John,OneConf_John,TwoConf_John,19,19,19\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testScalarFunctionConstructorWithParams(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + + val func0 = new Func13("default") + val func1 = new Func13("Sunny") + val func2 = new Func13("kevin2") + + val result = in.select(func0('c), func1('c),func2('c)) + + val results = result.collect() + + val expected = "default-Anna#44,Sunny-Anna#44,kevin2-Anna#44\n" + + "default-Jack#22,Sunny-Jack#22,kevin2-Jack#22\n" + + "default-John#19,Sunny-John#19,kevin2-John#19\n" + + "default-nosharp,Sunny-nosharp,kevin2-nosharp" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + private def testData( + env: ExecutionEnvironment) + : DataSet[(Int, Long, String)] = { + + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Jack#22")) + data.+=((2, 2L, "John#19")) + data.+=((3, 2L, "Anna#44")) + data.+=((4, 3L, "nosharp")) + env.fromCollection(data) + } +} 
http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala new file mode 100644 index 0000000..21b87e9 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataSetUserDefinedFunctionITCase.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.datastream + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} +import org.apache.flink.table.expressions.utils.{Func13, RichFunc2} +import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, TableFunc3, UserDefinedFunctionTestUtils} +import org.apache.flink.types.Row +import org.junit.Assert._ +import org.junit.Test + +import scala.collection.mutable + +class DataStreamUserDefinedFunctionITCase extends StreamingMultipleProgramsTestBase { + + @Test + def testCrossJoin(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = testData(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = t + .join(func0('c) as('d, 'e)) + .select('c, 'd, 'e) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testLeftOuterJoin(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = testData(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new TableFunc0 + + val result = t + .leftOuterJoin(func0('c) as('d, 'e)) + .select('c, 'd, 'e) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "nosharp,null,null", "Jack#22,Jack,22", + "John#19,John,19", "Anna#44,Anna,44") + assertEquals(expected.sorted, 
StreamITCase.testResults.sorted) + } + + @Test + def testUserDefinedTableFunctionWithParameter(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val tableFunc1 = new RichTableFunc1 + tEnv.registerFunction("RichTableFunc1", tableFunc1) + UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> " ")) + StreamITCase.testResults = mutable.MutableList() + + val result = StreamTestData.getSmall3TupleDataStream(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(tableFunc1('c) as 's) + .select('a, 's) + + val results = result.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("3,Hello", "3,world") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + val tableFunc1 = new RichTableFunc1 + val richFunc2 = new RichFunc2 + tEnv.registerFunction("RichTableFunc1", tableFunc1) + tEnv.registerFunction("RichFunc2", richFunc2) + UserDefinedFunctionTestUtils.setJobParameters( + env, + Map("word_separator" -> "#", "string.value" -> "test")) + StreamITCase.testResults = mutable.MutableList() + + val result = StreamTestData.getSmall3TupleDataStream(env) + .toTable(tEnv, 'a, 'b, 'c) + .join(tableFunc1(richFunc2('c)) as 's) + .select('a, 's) + + val results = result.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,Hi", + "1,test", + "2,Hello", + "2,test", + "3,Hello world", + "3,test") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testTableFunctionConstructorWithParams(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + 
StreamITCase.clear + + val t = testData(env).toTable(tEnv).as('a, 'b, 'c) + val config = Map("key1" -> "value1", "key2" -> "value2") + val func30 = new TableFunc3(null) + val func31 = new TableFunc3("OneConf_") + val func32 = new TableFunc3("TwoConf_", config) + + val result = t + .join(func30('c) as('d, 'e)) + .select('c, 'd, 'e) + .join(func31('c) as ('f, 'g)) + .select('c, 'd, 'e, 'f, 'g) + .join(func32('c) as ('h, 'i)) + .select('c, 'd, 'f, 'h, 'e, 'g, 'i) + .toDataStream[Row] + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "Anna#44,Anna,OneConf_Anna,TwoConf__key=key1_value=value1_Anna,44,44,44", + "Anna#44,Anna,OneConf_Anna,TwoConf__key=key2_value=value2_Anna,44,44,44", + "Jack#22,Jack,OneConf_Jack,TwoConf__key=key1_value=value1_Jack,22,22,22", + "Jack#22,Jack,OneConf_Jack,TwoConf__key=key2_value=value2_Jack,22,22,22", + "John#19,John,OneConf_John,TwoConf__key=key1_value=value1_John,19,19,19", + "John#19,John,OneConf_John,TwoConf__key=key2_value=value2_John,19,19,19" + ) + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testScalarFunctionConstructorWithParams(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = testData(env).toTable(tEnv).as('a, 'b, 'c) + val func0 = new Func13("default") + val func1 = new Func13("Sunny") + val func2 = new Func13("kevin2") + + val result = t.select(func0('c), func1('c),func2('c)) + + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "default-Anna#44,Sunny-Anna#44,kevin2-Anna#44", + "default-Jack#22,Sunny-Jack#22,kevin2-Jack#22", + "default-John#19,Sunny-John#19,kevin2-John#19", + "default-nosharp,Sunny-nosharp,kevin2-nosharp" + ) + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + private def testData( + env: StreamExecutionEnvironment) + : 
DataStream[(Int, Long, String)] = { + + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Jack#22")) + data.+=((2, 2L, "John#19")) + data.+=((3, 2L, "Anna#44")) + data.+=((4, 3L, "nosharp")) + env.fromCollection(data) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala deleted file mode 100644 index f8a697d..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.flink.table.runtime.datastream - -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment} -import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData} -import org.apache.flink.table.expressions.utils.RichFunc2 -import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, UserDefinedFunctionTestUtils} -import org.apache.flink.types.Row -import org.junit.Assert._ -import org.junit.Test - -import scala.collection.mutable - -class DataStreamCorrelateITCase extends StreamingMultipleProgramsTestBase { - - @Test - def testCrossJoin(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.clear - - val t = testData(env).toTable(tEnv).as('a, 'b, 'c) - val func0 = new TableFunc0 - - val result = t - .join(func0('c) as('d, 'e)) - .select('c, 'd, 'e) - .toDataStream[Row] - - result.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testLeftOuterJoin(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.clear - - val t = testData(env).toTable(tEnv).as('a, 'b, 'c) - val func0 = new TableFunc0 - - val result = t - .leftOuterJoin(func0('c) as('d, 'e)) - .select('c, 'd, 'e) - .toDataStream[Row] - - result.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = mutable.MutableList( - "nosharp,null,null", "Jack#22,Jack,22", - "John#19,John,19", "Anna#44,Anna,44") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } 
- - @Test - def testUserDefinedTableFunctionWithParameter(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - val tableFunc1 = new RichTableFunc1 - tEnv.registerFunction("RichTableFunc1", tableFunc1) - UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> " ")) - StreamITCase.testResults = mutable.MutableList() - - val result = StreamTestData.getSmall3TupleDataStream(env) - .toTable(tEnv, 'a, 'b, 'c) - .join(tableFunc1('c) as 's) - .select('a, 's) - - val results = result.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = mutable.MutableList("3,Hello", "3,world") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env) - val tableFunc1 = new RichTableFunc1 - val richFunc2 = new RichFunc2 - tEnv.registerFunction("RichTableFunc1", tableFunc1) - tEnv.registerFunction("RichFunc2", richFunc2) - UserDefinedFunctionTestUtils.setJobParameters( - env, - Map("word_separator" -> "#", "string.value" -> "test")) - StreamITCase.testResults = mutable.MutableList() - - val result = StreamTestData.getSmall3TupleDataStream(env) - .toTable(tEnv, 'a, 'b, 'c) - .join(tableFunc1(richFunc2('c)) as 's) - .select('a, 's) - - val results = result.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = mutable.MutableList( - "1,Hi", - "1,test", - "2,Hello", - "2,test", - "3,Hello world", - "3,test") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - private def testData( - env: StreamExecutionEnvironment) - : DataStream[(Int, Long, String)] = { - - val data = new mutable.MutableList[(Int, Long, String)] - data.+=((1, 1L, "Jack#22")) - data.+=((2, 2L, "John#19")) - 
data.+=((3, 2L, "Anna#44")) - data.+=((4, 3L, "nosharp")) - env.fromCollection(data) - } - -} http://git-wip-us.apache.org/repos/asf/flink/blob/45e01cf2/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala index 5db9d5f..88917a2 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala @@ -54,7 +54,6 @@ class TableFunc1 extends TableFunction[String] { } } - class TableFunc2 extends TableFunction[Row] { def eval(str: String): Unit = { if (str.contains("#")) { @@ -73,6 +72,41 @@ class TableFunc2 extends TableFunction[Row] { } } +class TableFunc3(data: String, conf: Map[String, String]) extends TableFunction[SimpleUser] { + + def this(data: String) { + this(data, null) + } + + def eval(user: String): Unit = { + if (user.contains("#")) { + val splits = user.split("#") + if (null != data) { + if (null != conf && conf.size > 0) { + val it = conf.keys.iterator + while (it.hasNext) { + val key = it.next() + val value = conf.get(key).get + collect( + SimpleUser( + data.concat("_key=") + .concat(key) + .concat("_value=") + .concat(value) + .concat("_") + .concat(splits(0)), + splits(1).toInt)) + } + } else { + collect(SimpleUser(data.concat(splits(0)), splits(1).toInt)) + } + } else { + collect(SimpleUser(splits(0), splits(1).toInt)) + } + } + } +} + class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] { def eval(user: String) { if (user.contains("#")) {
