Repository: flink Updated Branches: refs/heads/release-1.3 4c6b6c29d -> 310817035
[FLINK-6226] [table] Add tests for UDFs with Byte, Short, and Float arguments. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/31081703 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/31081703 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/31081703 Branch: refs/heads/release-1.3 Commit: 310817035fef1d843b37e00ccd4a32efffaad3dc Parents: 4c6b6c2 Author: Fabian Hueske <[email protected]> Authored: Thu Nov 2 21:10:03 2017 +0100 Committer: Fabian Hueske <[email protected]> Committed: Fri Nov 3 10:41:57 2017 +0100 ---------------------------------------------------------------------- .../UserDefinedScalarFunctionTest.scala | 28 ++++++++++++++++++-- .../utils/UserDefinedScalarFunctions.scala | 6 +++++ .../DataSetUserDefinedFunctionITCase.scala | 23 +++++++++++++++- .../table/utils/UserDefinedTableFunctions.scala | 12 ++++++++- 4 files changed, 65 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/31081703/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala index 56cdf3c..df83f7c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala @@ -48,6 +48,24 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { "43") testAllApis( + Func1('f11), + "Func1(f11)", + "Func1(f11)", + "4") + + testAllApis( + Func1('f12), + "Func1(f12)", + "Func1(f12)", + "4") + + testAllApis( + Func1('f13), + "Func1(f13)", + "Func1(f13)", + "4.0") + + testAllApis( Func2('f0, 'f1, 'f3), "Func2(f0, f1, f3)", "Func2(f0, f1, f3)", @@ -360,7 +378,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { // ---------------------------------------------------------------------------------------------- override def testData: Any = { - val testData = new Row(11) + val testData = new Row(14) testData.setField(0, 42) testData.setField(1, "Test") testData.setField(2, null) @@ -372,6 +390,9 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { testData.setField(8, 1000L) testData.setField(9, Seq("Hello", "World")) testData.setField(10, Array[Integer](1, 2, null)) + testData.setField(11, 3.toByte) + testData.setField(12, 3.toShort) + testData.setField(13, 3.toFloat) testData } @@ -387,7 +408,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase { Types.INTERVAL_MONTHS, Types.INTERVAL_MILLIS, TypeInformation.of(classOf[Seq[String]]), - BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO + BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO, + Types.BYTE, + Types.SHORT, + Types.FLOAT ).asInstanceOf[TypeInformation[Any]] } http://git-wip-us.apache.org/repos/asf/flink/blob/31081703/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala index 5285569..9535cdf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala @@ -41,6 +41,12 @@ object Func1 extends ScalarFunction { def eval(index: Integer): Integer = { index + 1 } + + def eval(b: Byte): Byte = (b + 1).toByte + + def eval(s: Short): Short = (s + 1).toShort + + def eval(f: Float): Float = f + 1 } object Func2 extends ScalarFunction { http://git-wip-us.apache.org/repos/asf/flink/blob/31081703/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala index 1860d3c..9755596 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetUserDefinedFunctionITCase.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.flink.api.scala._ import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0 -import org.apache.flink.table.api.{TableEnvironment, ValidationException} +import org.apache.flink.table.api.{TableEnvironment, Types, ValidationException} import org.apache.flink.table.api.scala._ import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode @@ -229,6 +229,27 @@ class DataSetUserDefinedFunctionITCase( } @Test + def testByteShortFloatArguments(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tableEnv = TableEnvironment.getTableEnvironment(env, config) + val in = testData(env).toTable(tableEnv).as('a, 'b, 'c) + val tFunc = new TableFunc4 + + val result = in + .select('a.cast(Types.BYTE) as 'a, 'a.cast(Types.SHORT) as 'b, 'b.cast(Types.FLOAT) as 'c) + .join(tFunc('a, 'b, 'c) as ('a2, 'b2, 'c2)) + .toDataSet[Row] + + val results = result.collect() + val expected = Seq( + "1,1,1.0,Byte=1,Short=1,Float=1.0", + "2,2,2.0,Byte=2,Short=2,Float=2.0", + "3,3,2.0,Byte=3,Short=3,Float=2.0", + "4,4,3.0,Byte=4,Short=4,Float=3.0").mkString("\n") + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test def testUserDefinedTableFunctionWithParameter(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env) http://git-wip-us.apache.org/repos/asf/flink/blob/31081703/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala index d0ffade..e1af23b 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala @@ -22,7 +22,7 @@ import java.lang.Boolean import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.java.tuple.Tuple3 import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.api.{Types, ValidationException} import org.apache.flink.table.functions.{FunctionContext, TableFunction} import org.apache.flink.types.Row import org.junit.Assert @@ -109,6 +109,16 @@ class TableFunc3(data: String, conf: Map[String, String]) extends TableFunction[ } } +class TableFunc4 extends TableFunction[Row] { + def eval(b: Byte, s: Short, f: Float): Unit = { + collect(Row.of("Byte=" + b, "Short=" + s, "Float=" + f)) + } + + override def getResultType: TypeInformation[Row] = { + new RowTypeInfo(Types.STRING, Types.STRING, Types.STRING) + } +} + class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] { def eval(user: String) { if (user.contains("#")) {
