twalthr closed pull request #3456: [FLINK-5832] [table] Support for simple hive UDF
URL: https://github.com/apache/flink/pull/3456
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/flink-connectors/flink-hcatalog/pom.xml b/flink-connectors/flink-hcatalog/pom.xml
index ba0e142692b..afb6914281b 100644
--- a/flink-connectors/flink-hcatalog/pom.xml
+++ b/flink-connectors/flink-hcatalog/pom.xml
@@ -42,6 +42,12 @@ under the License.
                        <scope>provided</scope>
                </dependency>
 
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-table_2.10</artifactId>
+                       <version>${project.version}</version>
+               </dependency>
+
                <dependency>
                        <groupId>org.apache.flink</groupId>
                        <artifactId>flink-hadoop-compatibility_2.10</artifactId>
@@ -50,8 +56,8 @@ under the License.
 
                <dependency>
                        <groupId>org.apache.hive.hcatalog</groupId>
-                       <artifactId>hcatalog-core</artifactId>
-                       <version>0.12.0</version>
+                       <artifactId>hive-hcatalog-core</artifactId>
+                       <version>0.13.0</version>
                        <exclusions>
                                <exclusion>
                                        <groupId>org.json</groupId>
diff --git a/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveFunctionWrapper.scala b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveFunctionWrapper.scala
new file mode 100644
index 00000000000..27f042036d9
--- /dev/null
+++ b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveFunctionWrapper.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.hive.functions
+
+import org.apache.hadoop.hive.ql.exec.UDF
+
+private[hive] case class HiveFunctionWrapper(
+  var functionClassName: String,
+  private var instance: AnyRef = null) {
+
+  def createFunction[UDFType <: AnyRef](): UDFType = {
+    if (instance != null) {
+      instance.asInstanceOf[UDFType]
+    } else {
+      val func = getClassLoader.loadClass(functionClassName).newInstance().asInstanceOf[UDFType]
+      if (!func.isInstanceOf[UDF]) {
+        instance = func
+      }
+      func
+    }
+  }
+
+  def getClassLoader: ClassLoader = {
+    Thread.currentThread.getContextClassLoader
+  }
+
+}
diff --git a/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveSimpleUDF.scala b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveSimpleUDF.scala
new file mode 100644
index 00000000000..00018bb47d0
--- /dev/null
+++ b/flink-connectors/flink-hcatalog/src/main/scala/org/apache/flink/table/hive/functions/HiveSimpleUDF.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.hive.functions
+
+import java.lang.reflect.Method
+import java.math.BigDecimal
+import java.util
+
+import org.apache.flink.table.functions.ScalarFunction
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
+import org.apache.hadoop.hive.serde2.typeinfo.{PrimitiveTypeInfo, TypeInfo, TypeInfoFactory}
+
+import scala.annotation.varargs
+
+/**
+  * A Hive UDF wrapper that behaves as a Flink Table API ScalarFunction.
+  *
+  * The eval method must be annotated with @varargs, because Scala compiles
+  * <code>eval(args: Any*)</code> to <code>eval(args: Seq)</code>, which causes
+  * an exception in the Janino compiler.
+  */
+class HiveSimpleUDF(className: String) extends ScalarFunction {
+
+  @transient
+  private lazy val functionWrapper = HiveFunctionWrapper(className)
+
+  @transient
+  private lazy val function = functionWrapper.createFunction[UDF]()
+
+  @transient
+  private var typeInfos: util.List[TypeInfo] = _
+
+  @transient
+  private var objectInspectors: Array[ObjectInspector] = _
+
+  @transient
+  private var conversionHelper: ConversionHelper = _
+
+  @transient
+  private var method: Method = _
+
+  @varargs
+  def eval(args: AnyRef*) : Any = {
+    if (null == typeInfos) {
+      typeInfos = new util.ArrayList[TypeInfo]()
+      args.foreach(arg => {
+          typeInfos.add(TypeInfoFactory.getPrimitiveTypeInfoFromJavaPrimitive(arg.getClass))
+      })
+      method = function.getResolver.getEvalMethod(typeInfos)
+
+      objectInspectors = new Array[ObjectInspector](typeInfos.size())
+      args.zipWithIndex.foreach { case (_, i) =>
+        objectInspectors(i) = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+          typeInfos.get(i).asInstanceOf[PrimitiveTypeInfo])
+      }
+      conversionHelper = new ConversionHelper(method, objectInspectors)
+    }
+
+    val mappedArgs = args.map {
+      case arg: BigDecimal =>
+        arg.asInstanceOf[BigDecimal].doubleValue().asInstanceOf[AnyRef]
+      case arg: AnyRef =>
+        arg
+    }
+
+    FunctionRegistry.invoke(method, function,
+      conversionHelper.convertIfNecessary(mappedArgs: _*): _*)
+  }
+}
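
A side note on the @varargs requirement called out in the scaladoc above: it applies to any Scala ScalarFunction with variable arguments, not only this wrapper. A minimal sketch, using a hypothetical ConcatAll function that is not part of this patch:

import scala.annotation.varargs
import org.apache.flink.table.functions.ScalarFunction

// Hypothetical example. Without @varargs, Scala compiles eval(args: AnyRef*) to
// eval(args: Seq[AnyRef]), which the Janino-generated calling code cannot invoke;
// the annotation adds a Java-style eval(Object... args) bridge method.
class ConcatAll extends ScalarFunction {
  @varargs
  def eval(args: AnyRef*): String = args.mkString("|")
}
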
diff --git a/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/HiveScalarFunctionTest.scala b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/HiveScalarFunctionTest.scala
new file mode 100644
index 00000000000..98dd3666e37
--- /dev/null
+++ b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/HiveScalarFunctionTest.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.hive.functions
+
+import java.sql.{Date, Time, Timestamp}
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.api.Types
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.hive.functions.utils.{ExpressionTestBase, SimplePojo}
+import org.apache.flink.types.Row
+import org.junit.Test
+
+class HiveScalarFunctionTest extends ExpressionTestBase {
+
+  @Test
+  def testHiveSimpleFunctions(): Unit = {
+    val HiveUDFAcos = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAcos")
+    testAllApis(
+      HiveUDFAcos(1.0),
+      "HiveUDFAcos(1.0)",
+      "HiveUDFAcos(1.0)",
+      "0.0"
+    )
+
+    val HiveUDFAscii = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAscii")
+    testAllApis(
+      HiveUDFAscii("0"),
+      "HiveUDFAscii('0')",
+      "HiveUDFAscii('0')",
+      "48"
+    )
+
+    val HiveUDFAsin = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAsin")
+    testAllApis(
+      HiveUDFAsin("0"),
+      "HiveUDFAsin('0')",
+      "HiveUDFAsin('0')",
+      "0.0"
+    )
+
+    val HiveUDFBin = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFBin")
+    testAllApis(
+      HiveUDFBin(13),
+      "HiveUDFBin(13)",
+      "HiveUDFBin(13)",
+      "1101"
+    )
+
+    val HiveUDFConv = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFConv")
+    testAllApis(
+      HiveUDFConv("100", 2, 10),
+      "HiveUDFConv('100', 2, 10)",
+      "HiveUDFConv('100', 2, 10)",
+      "4"
+    )
+    testAllApis(
+      HiveUDFConv(-10, 16, -10),
+      "HiveUDFConv(-10, 16, -10)",
+      "HiveUDFConv(-10, 16, -10)",
+      "-16"
+    )
+
+    val HiveUDFCos = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFCos")
+    testAllApis(
+      HiveUDFCos(0.0),
+      "HiveUDFCos(0.0)",
+      "HiveUDFCos(0.0)",
+      "1.0"
+    )
+
+    val HiveUDFDayOfMonth = new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFDayOfMonth")
+    testAllApis(
+      HiveUDFDayOfMonth("2009-07-30"),
+      "HiveUDFDayOfMonth('2009-07-30')",
+      "HiveUDFDayOfMonth('2009-07-30')",
+      "30"
+    )
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  override def testData: Any = {
+    val testData = new Row(9)
+    testData.setField(0, 42)
+    testData.setField(1, "Test")
+    testData.setField(2, null)
+    testData.setField(3, SimplePojo("Bob", 36))
+    testData.setField(4, Date.valueOf("1990-10-14"))
+    testData.setField(5, Time.valueOf("12:10:10"))
+    testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10"))
+    testData.setField(7, 12)
+    testData.setField(8, 1000L)
+    testData
+  }
+
+  override def typeInfo: TypeInformation[Any] = {
+    new RowTypeInfo(
+      Types.INT,
+      Types.STRING,
+      Types.BOOLEAN,
+      TypeInformation.of(classOf[SimplePojo]),
+      Types.DATE,
+      Types.TIME,
+      Types.TIMESTAMP,
+      Types.INTERVAL_MONTHS,
+      Types.INTERVAL_MILLIS
+    ).asInstanceOf[TypeInformation[Any]]
+  }
+
+  override def functions: Map[String, ScalarFunction] = Map(
+    "HiveUDFAcos" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAcos"),
+    "HiveUDFAscii" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAscii"),
+    "HiveUDFAsin" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFAsin"),
+    "HiveUDFBin" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFBin"),
+    "HiveUDFConv" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFConv"),
+    "HiveUDFCos" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFCos"),
+    "HiveUDFDayOfMonth" -> new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFDayOfMonth")
+  )
+}
diff --git a/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/ExpressionTestBase.scala b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/ExpressionTestBase.scala
new file mode 100644
index 00000000000..f9eb1bf445f
--- /dev/null
+++ b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/ExpressionTestBase.scala
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.hive.functions.utils
+
+import java.util
+import java.util.concurrent.Future
+
+import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder}
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql2rel.RelDecorrelator
+import org.apache.calcite.tools.{Programs, RelBuilder}
+import org.apache.flink.api.common.TaskInfo
+import org.apache.flink.api.common.accumulators.Accumulator
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.common.functions.util.RuntimeUDFContext
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.java.{DataSet => JDataSet}
+import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.fs.Path
+import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableEnvironment}
+import org.apache.flink.table.calcite.FlinkPlannerImpl
+import org.apache.flink.table.codegen.{CodeGenerator, Compiler, GeneratedFunction}
+import org.apache.flink.table.expressions.{Expression, ExpressionParser}
+import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention}
+import org.apache.flink.table.plan.rules.FlinkRuleSets
+import org.apache.flink.types.Row
+import org.junit.Assert._
+import org.junit.{After, Before}
+import org.mockito.Mockito._
+
+import scala.collection.mutable
+
+/**
+  * Base test class for expression tests.
+  */
+abstract class ExpressionTestBase {
+
+  private val testExprs = mutable.ArrayBuffer[(RexNode, String)]()
+
+  // setup test utils
+  private val tableName = "testTable"
+  private val context = prepareContext(typeInfo)
+  private val planner = new FlinkPlannerImpl(
+    context._2.getFrameworkConfig,
+    context._2.getPlanner,
+    context._2.getTypeFactory)
+  private val optProgram = Programs.ofRules(FlinkRuleSets.DATASET_OPT_RULES)
+
+  private def hepPlanner = {
+    val builder = new HepProgramBuilder
+    builder.addMatchOrder(HepMatchOrder.BOTTOM_UP)
+    val it = FlinkRuleSets.DATASET_NORM_RULES.iterator()
+    while (it.hasNext) {
+      builder.addRuleInstance(it.next())
+    }
+    new HepPlanner(builder.build, context._2.getFrameworkConfig.getContext)
+  }
+
+  private def prepareContext(typeInfo: TypeInformation[Any])
+    : (RelBuilder, TableEnvironment, ExecutionEnvironment) = {
+    // create DataSetTable
+    val dataSetMock = mock(classOf[DataSet[Any]])
+    val jDataSetMock = mock(classOf[JDataSet[Any]])
+    when(dataSetMock.javaSet).thenReturn(jDataSetMock)
+    when(jDataSetMock.getType).thenReturn(typeInfo)
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    tEnv.registerDataSet(tableName, dataSetMock)
+    functions.foreach(f => tEnv.registerFunction(f._1, f._2))
+
+    // prepare RelBuilder
+    val relBuilder = tEnv.getRelBuilder
+    relBuilder.scan(tableName)
+
+    (relBuilder, tEnv, env)
+  }
+
+  def testData: Any
+
+  def typeInfo: TypeInformation[Any]
+
+  def functions: Map[String, ScalarFunction] = Map()
+
+  @Before
+  def resetTestExprs() = {
+    testExprs.clear()
+  }
+
+  @After
+  def evaluateExprs() = {
+    val relBuilder = context._1
+    val config = new TableConfig()
+    val generator = new CodeGenerator(config, false, typeInfo)
+
+    // cast expressions to String
+    val stringTestExprs = testExprs.map(expr => relBuilder.cast(expr._1, VARCHAR))
+
+    // generate code
+    val resultType = new RowTypeInfo(Seq.fill(testExprs.size)(STRING_TYPE_INFO): _*)
+    val genExpr = generator.generateResultExpression(
+      resultType,
+      resultType.getFieldNames,
+      stringTestExprs)
+
+    val bodyCode =
+      s"""
+        |${genExpr.code}
+        |return ${genExpr.resultTerm};
+        |""".stripMargin
+
+    val genFunc = generator.generateFunction[MapFunction[Any, Row], Row](
+      "TestFunction",
+      classOf[MapFunction[Any, Row]],
+      bodyCode,
+      resultType)
+
+    // compile and evaluate
+    val clazz = new TestCompiler[MapFunction[Any, Row], Row]().compile(genFunc)
+    val mapper = clazz.newInstance()
+
+    val isRichFunction = mapper.isInstanceOf[RichFunction]
+
+    // call setRuntimeContext method and open method for RichFunction
+    if (isRichFunction) {
+      val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]]
+      val t = new RuntimeUDFContext(
+        new TaskInfo("ExpressionTest", 1, 0, 1, 1),
+        null,
+        context._3.getConfig,
+        new util.HashMap[String, Future[Path]](),
+        new util.HashMap[String, Accumulator[_, _]](),
+        null)
+      richMapper.setRuntimeContext(t)
+      richMapper.open(new Configuration())
+    }
+
+    val result = mapper.map(testData)
+
+    // call close method for RichFunction
+    if (isRichFunction) {
+      mapper.asInstanceOf[RichMapFunction[_, _]].close()
+    }
+
+    // compare
+    testExprs
+      .zipWithIndex
+      .foreach {
+        case ((expr, expected), index) =>
+          val actual = result.getField(index)
+          assertEquals(
+            s"Wrong result for: $expr",
+            expected,
+            if (actual == null) "null" else actual)
+      }
+  }
+
+  private def addSqlTestExpr(sqlExpr: String, expected: String): Unit = {
+    // create RelNode from SQL expression
+    val parsed = planner.parse(s"SELECT $sqlExpr FROM $tableName")
+    val validated = planner.validate(parsed)
+    val converted = planner.rel(validated).rel
+
+    val decorPlan = RelDecorrelator.decorrelateQuery(converted)
+
+    // normalize
+    val normalizedPlan = if (FlinkRuleSets.DATASET_NORM_RULES.iterator().hasNext) {
+      val planner = hepPlanner
+      planner.setRoot(decorPlan)
+      planner.findBestExp
+    } else {
+      decorPlan
+    }
+
+    // create DataSetCalc
+    val flinkOutputProps = converted.getTraitSet.replace(DataSetConvention.INSTANCE).simplify()
+    val dataSetCalc = optProgram.run(context._2.getPlanner, normalizedPlan, flinkOutputProps)
+
+    // extract RexNode
+    val calcProgram = dataSetCalc
+     .asInstanceOf[DataSetCalc]
+     .calcProgram
+    val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0))
+
+    testExprs += ((expanded, expected))
+  }
+
+  private def addTableApiTestExpr(tableApiExpr: Expression, expected: String): Unit = {
+    // create RelNode from Table API expression
+    val env = context._2
+    val converted = env
+      .asInstanceOf[BatchTableEnvironment]
+      .scan(tableName)
+      .select(tableApiExpr)
+      .getRelNode
+
+    // create DataSetCalc
+    val decorPlan = RelDecorrelator.decorrelateQuery(converted)
+    val flinkOutputProps = converted.getTraitSet.replace(DataSetConvention.INSTANCE).simplify()
+    val dataSetCalc = optProgram.run(context._2.getPlanner, decorPlan, flinkOutputProps)
+
+    // extract RexNode
+    val calcProgram = dataSetCalc
+     .asInstanceOf[DataSetCalc]
+     .calcProgram
+    val expanded = calcProgram.expandLocalRef(calcProgram.getProjectList.get(0))
+
+    testExprs += ((expanded, expected))
+  }
+
+  private def addTableApiTestExpr(tableApiString: String, expected: String): Unit = {
+    addTableApiTestExpr(ExpressionParser.parseExpression(tableApiString), expected)
+  }
+
+  def testAllApis(
+      expr: Expression,
+      exprString: String,
+      sqlExpr: String,
+      expected: String)
+    : Unit = {
+    addTableApiTestExpr(expr, expected)
+    addTableApiTestExpr(exprString, expected)
+    addSqlTestExpr(sqlExpr, expected)
+  }
+
+  def testTableApi(
+      expr: Expression,
+      exprString: String,
+      expected: String)
+    : Unit = {
+    addTableApiTestExpr(expr, expected)
+    addTableApiTestExpr(exprString, expected)
+  }
+
+  def testSqlApi(
+      sqlExpr: String,
+      expected: String)
+    : Unit = {
+    addSqlTestExpr(sqlExpr, expected)
+  }
+
+  // ----------------------------------------------------------------------------------------------
+
+  // TestCompiler that uses current class loader
+  class TestCompiler[F <: Function, T <: Any] extends Compiler[F] {
+    def compile(genFunc: GeneratedFunction[F, T]): Class[F] =
+      compile(getClass.getClassLoader, genFunc.name, genFunc.code)
+  }
+}
diff --git a/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/SimplePojo.scala b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/SimplePojo.scala
new file mode 100644
index 00000000000..1ebe8bcad72
--- /dev/null
+++ b/flink-connectors/flink-hcatalog/src/test/scala/org/apache/flink/table/hive/functions/utils/SimplePojo.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.hive.functions.utils
+
+case class SimplePojo(name: String, age: Int)
+
diff --git a/flink-libraries/flink-table/pom.xml b/flink-libraries/flink-table/pom.xml
index c6071b06162..2d9f6a38d73 100644
--- a/flink-libraries/flink-table/pom.xml
+++ b/flink-libraries/flink-table/pom.xml
@@ -92,7 +92,6 @@ under the License.
                        </exclusions>
                </dependency>
 
-
                <!-- test dependencies -->
 
                <dependency>
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
index 7ff18eb6332..2a8ba28c3ca 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarFunctionCallGen.scala
@@ -44,14 +44,22 @@ class ScalarFunctionCallGen(
       operands: Seq[GeneratedExpression])
     : GeneratedExpression = {
     // determine function signature and result class
-    val matchingSignature = getSignature(scalarFunction, signature)
+    val matchingMethod = getEvalMethod(scalarFunction, signature)
       .getOrElse(throw new CodeGenException("No matching signature found."))
+    val matchingSignature = matchingMethod.getParameterTypes
     val resultClass = getResultTypeClass(scalarFunction, matchingSignature)
 
+    // zip for variable signatures
+    var paramToOperands = matchingSignature.zip(operands)
+    if (operands.length > matchingSignature.length) {
+      operands.drop(matchingSignature.length).foreach(op =>
+        paramToOperands = paramToOperands :+
+          (matchingSignature.last.getComponentType, op)
+      )
+    }
+
     // convert parameters for function (output boxing)
-    val parameters = matchingSignature
-        .zip(operands)
-        .map { case (paramClass, operandExpr) =>
+    val parameters = paramToOperands.map { case (paramClass, operandExpr) =>
           if (paramClass.isPrimitive) {
             operandExpr
           } else if (ClassUtils.isPrimitiveWrapper(paramClass)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
index da652e043d1..11021203498 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -112,9 +112,16 @@ object ScalarSqlFunction {
           .getParameterTypes(foundSignature)
           .map(typeFactory.createTypeFromTypeInfo)
 
-        inferredTypes.zipWithIndex.foreach {
-          case (inferredType, i) =>
-            operandTypes(i) = inferredType
+        operandTypes.zipWithIndex.foreach {
+          case (_, i) =>
+            if (i < inferredTypes.length - 1) {
+              operandTypes(i) = inferredTypes(i)
+            } else if (null != inferredTypes.last.getComponentType) {
+              // the last parameter is varargs (an array type), so use its component type
+              operandTypes(i) = inferredTypes.last.getComponentType
+            } else {
+              operandTypes(i) = inferredTypes.last
+            }
         }
       }
     }
@@ -136,8 +143,18 @@ object ScalarSqlFunction {
       }
 
       override def getOperandCountRange: SqlOperandCountRange = {
-        val signatureLengths = signatures.map(_.length)
-        SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
+        var min = 255
+        var max = -1
+        signatures.foreach(sig => {
+          var len = sig.length
+          if (len > 0 && sig(sig.length - 1).isArray) {
+            max = 254  // according to JVM spec 4.3.3
+            len = sig.length - 1
+          }
+          max = Math.max(len, max)
+          min = Math.min(len, min)
+        })
+        SqlOperandCountRanges.between(min, max)
       }
 
       override def checkOperandTypes(
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 21d28b5e591..2f0756b9afc 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -78,20 +78,7 @@ object UserDefinedFunctionUtils {
       function: UserDefinedFunction,
       signature: Seq[TypeInformation[_]])
     : Option[Array[Class[_]]] = {
-    // We compare the raw Java classes not the TypeInformation.
-    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
-    val actualSignature = typeInfoToClass(signature)
-    val signatures = getSignatures(function)
-
-    signatures
-      // go over all signatures and find one matching actual signature
-      .find { curSig =>
-      // match parameters of signature to actual parameters
-      actualSignature.length == curSig.length &&
-        curSig.zipWithIndex.forall { case (clazz, i) =>
-          parameterTypeEquals(actualSignature(i), clazz)
-        }
-    }
+    getEvalMethod(function, signature).map(_.getParameterTypes)
   }
 
   /**
@@ -106,16 +93,53 @@ object UserDefinedFunctionUtils {
     val actualSignature = typeInfoToClass(signature)
     val evalMethods = checkAndExtractEvalMethods(function)
 
-    evalMethods
-      // go over all eval methods and find one matching
-      .find { cur =>
-      val signatures = cur.getParameterTypes
-      // match parameters of signature to actual parameters
-      actualSignature.length == signatures.length &&
-        signatures.zipWithIndex.forall { case (clazz, i) =>
-          parameterTypeEquals(actualSignature(i), clazz)
+    val filtered = evalMethods
+      // go over all eval methods and filter out matching methods
+      .filter {
+        case cur if !cur.isVarArgs =>
+          val signatures = cur.getParameterTypes
+          // match parameters of signature to actual parameters
+          actualSignature.length == signatures.length &&
+            signatures.zipWithIndex.forall { case (clazz, i) =>
+              parameterTypeEquals(actualSignature(i), clazz)
+          }
+        case cur if cur.isVarArgs =>
+          val signatures = cur.getParameterTypes
+          actualSignature.zipWithIndex.forall {
+            case (clazz, i) if i < signatures.length - 1  =>
+              parameterTypeEquals(clazz, signatures(i))
+            case (clazz, i) if i >= signatures.length - 1 =>
+              parameterTypeEquals(clazz, signatures.last.getComponentType)
+          } ||
+          (actualSignature.isEmpty && signatures.length == 1)
+    }
+
+    // if a fixed-arity method matches, it is preferred over a varargs method
+    val fixedMethods = filtered.count{!_.isVarArgs}
+    val found = filtered.filter { cur =>
+      fixedMethods > 0 && !cur.isVarArgs ||
+      fixedMethods == 0 && cur.isVarArgs
+    }
+
+    if (found.isEmpty &&
+      // check whether an eval method uses Scala variable arguments (Type*) without @varargs
+      evalMethods.exists{ evalMethod =>
+        val signatures = evalMethod.getParameterTypes
+        signatures.zipWithIndex.forall {
+          case (clazz, i) if i < signatures.length - 1 =>
+            parameterTypeEquals(actualSignature(i), clazz)
+          case (clazz, i) if i == signatures.length - 1 =>
+            clazz.getName.equals("scala.collection.Seq")
         }
+      }) {
+      throw new ValidationException("The 'eval' method does not support Scala " +
+        "variable arguments (e.g. Type*). Please add a @scala.annotation.varargs " +
+        "annotation to your 'eval' method.")
+    } else if (found.length > 1) {
+      throw new ValidationException("Found multiple 'eval' methods which " +
+        "match the signature.")
     }
+    found.headOption
   }
 
   /**
@@ -133,7 +157,7 @@ object UserDefinedFunctionUtils {
 
   /**
     * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
-    * can be found.
+    * can be found, or the implementation does not match the requirements.
     */
   def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
     val methods = function
@@ -152,9 +176,9 @@ object UserDefinedFunctionUtils {
         s"Function class '${function.getClass.getCanonicalName}' does not 
implement at least " +
           s"one method named 'eval' which is public, not abstract and " +
           s"(in case of table functions) not static.")
-    } else {
-      methods
     }
+
+    methods
   }
 
   def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
@@ -317,6 +341,7 @@ object UserDefinedFunctionUtils {
   private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
   candidate == null ||
     candidate == expected ||
+    expected == classOf[Object] ||
     expected.isPrimitive && Primitives.wrap(expected) == candidate ||
     candidate == classOf[Date] && (expected == classOf[Int] || expected == classOf[JInt])  ||
     candidate == classOf[Time] && (expected == classOf[Int] || expected == classOf[JInt]) ||
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
index e817f06b4e1..56f866d2b11 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/utils/UserDefinedScalarFunctions.java
@@ -33,4 +33,24 @@ public String eval(Integer a, int b,  Long c) {
                }
        }
 
+       public static class JavaFunc2 extends ScalarFunction {
+               public String eval(String s, Integer... a) {
+                       int m = 1;
+                       for (int n : a) {
+                               m *= n;
+                       }
+                       return s + m;
+               }
+       }
+
+       public static class JavaFunc3 extends ScalarFunction {
+               public int eval(String a, int... b) {
+                       return b.length;
+               }
+
+               public String eval(String c) {
+                       return c;
+               }
+       }
+
 }
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index a6c1760c9b8..4985e410eee 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.types.Row
-import org.apache.flink.table.api.Types
-import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1}
+import org.apache.flink.table.api.{Types, ValidationException}
+import org.apache.flink.table.api.java.utils.UserDefinedScalarFunctions.{JavaFunc0, JavaFunc1, JavaFunc2, JavaFunc3}
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.expressions.utils._
 import org.apache.flink.table.functions.ScalarFunction
@@ -180,6 +180,85 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       "+0 00:00:01.000")
   }
   
+  @Test
+  def testVariableArgs(): Unit = {
+    testAllApis(
+      Func14(1, 2, 3, 4),
+      "Func14(1, 2, 3, 4)",
+      "Func14(1, 2, 3, 4)",
+      "10")
+
+    // Test for empty arguments
+    testAllApis(
+      Func14(),
+      "Func14()",
+      "Func14()",
+      "0")
+
+    // Test for override
+    testAllApis(
+      Func15("Hello"),
+      "Func15('Hello')",
+      "Func15('Hello')",
+      "Hello"
+    )
+
+    testAllApis(
+      Func15('f1),
+      "Func15(f1)",
+      "Func15(f1)",
+      "Test"
+    )
+
+    testAllApis(
+      Func15("Hello", 1, 2, 3),
+      "Func15('Hello', 1, 2, 3)",
+      "Func15('Hello', 1, 2, 3)",
+      "Hello3"
+    )
+
+    testAllApis(
+      Func16('f9),
+      "Func16(f9)",
+      "Func16(f9)",
+      "Hello, World"
+    )
+
+    try {
+      testAllApis(
+        Func17("Hello", "World"),
+        "Func17('Hello', 'World')",
+        "Func17('Hello', 'World')",
+        "Hello, World"
+      )
+      throw new RuntimeException("Shouldn't be reached here!")
+    } catch {
+      case ex: ValidationException =>
+        // expected: Func17 declares Scala varargs without @varargs
+    }
+
+    val JavaFunc2 = new JavaFunc2
+    testAllApis(
+      JavaFunc2("Hi", 1, 3, 5, 7),
+      "JavaFunc2('Hi', 1, 3, 5, 7)",
+      "JavaFunc2('Hi', 1, 3, 5, 7)",
+      "Hi105")
+
+    // Test for override
+    val JavaFunc3 = new JavaFunc3
+    testAllApis(
+      JavaFunc3("Hi"),
+      "JavaFunc3('Hi')",
+      "JavaFunc3('Hi')",
+      "Hi")
+
+    testAllApis(
+      JavaFunc3('f1),
+      "JavaFunc3(f1)",
+      "JavaFunc3(f1)",
+      "Test")
+  }
+
   @Test
   def testJavaBoxedPrimitives(): Unit = {
     val JavaFunc0 = new JavaFunc0()
@@ -235,10 +314,11 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       "#Test")
   }
 
+
   // ----------------------------------------------------------------------------------------------
 
   override def testData: Any = {
-    val testData = new Row(9)
+    val testData = new Row(10)
     testData.setField(0, 42)
     testData.setField(1, "Test")
     testData.setField(2, null)
@@ -248,6 +328,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     testData.setField(6, Timestamp.valueOf("1990-10-14 12:10:10"))
     testData.setField(7, 12)
     testData.setField(8, 1000L)
+    testData.setField(9, Seq("Hello", "World"))
     testData
   }
 
@@ -261,7 +342,8 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       Types.TIME,
       Types.TIMESTAMP,
       Types.INTERVAL_MONTHS,
-      Types.INTERVAL_MILLIS
+      Types.INTERVAL_MILLIS,
+      TypeInformation.of(classOf[Seq[String]])
     ).asInstanceOf[TypeInformation[Any]]
   }
 
@@ -279,8 +361,14 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     "Func10" -> Func10,
     "Func11" -> Func11,
     "Func12" -> Func12,
+    "Func14" -> Func14,
+    "Func15" -> Func15,
+    "Func16" -> Func16,
+    "Func17" -> Func17,
     "JavaFunc0" -> new JavaFunc0,
     "JavaFunc1" -> new JavaFunc1,
+    "JavaFunc2" -> new JavaFunc2,
+    "JavaFunc3" -> new JavaFunc3,
     "RichFunc0" -> new RichFunc0,
     "RichFunc1" -> new RichFunc1,
     "RichFunc2" -> new RichFunc2
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index 1258137df7e..982a1d6625c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -28,6 +28,8 @@ import org.junit.Assert
 import scala.collection.mutable
 import scala.io.Source
 
+import scala.annotation.varargs
+
 case class SimplePojo(name: String, age: Int)
 
 object Func0 extends ScalarFunction {
@@ -227,3 +229,37 @@ class Func13(prefix: String) extends ScalarFunction {
   }
 }
 
+object Func14 extends ScalarFunction {
+
+  @varargs
+  def eval(a: Int*): Int = {
+    a.sum
+  }
+}
+
+object Func15 extends ScalarFunction {
+
+  @varargs
+  def eval(a: String, b: Int*): String = {
+    a + b.length
+  }
+
+  def eval(a: String): String = {
+    a
+  }
+}
+
+object Func16 extends ScalarFunction {
+
+  def eval(a: Seq[String]): String = {
+    a.mkString(", ")
+  }
+}
+
+object Func17 extends ScalarFunction {
+
+  // Without @varargs, calling this method is rejected with a ValidationException
+  def eval(a: String*): String = {
+    a.mkString(", ")
+  }
+}
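
For reference, a minimal usage sketch of the new wrapper (not part of the patch; the program, table name and field name are illustrative, mirroring the registration pattern used in HiveScalarFunctionTest above):

import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.table.hive.functions.HiveSimpleUDF

object HiveUdfExample {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // Wrap a built-in Hive UDF and register it like any other ScalarFunction.
    tEnv.registerFunction("hiveBin",
      new HiveSimpleUDF("org.apache.hadoop.hive.ql.udf.UDFBin"))

    // Illustrative input table with a single INT column 'x.
    tEnv.registerDataSet("MyTable", env.fromElements(13, 21), 'x)

    // The wrapper resolves UDFBin's eval method reflectively on the first call.
    val result = tEnv.sql("SELECT hiveBin(x) FROM MyTable")
    result.toDataSet[String].print()
  }
}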


 
