This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e07bde47 [VL] UDF: Support variable arity in function signatures
(#5495)
1e07bde47 is described below

commit 1e07bde4710961ae8e3457c7a07c51d03b08cf09
Author: Rong Ma <[email protected]>
AuthorDate: Fri Apr 26 08:52:33 2024 +0800

    [VL] UDF: Support variable arity in function signatures (#5495)
---
 .../apache/spark/sql/expression/UDFResolver.scala  | 155 ++++++++++---
 .../apache/gluten/expression/VeloxUdfSuite.scala   |  18 +-
 cpp/velox/jni/JniUdf.cc                            |   9 +-
 cpp/velox/udf/Udaf.h                               |   1 +
 cpp/velox/udf/Udf.h                                |   2 +
 cpp/velox/udf/UdfLoader.cc                         |  31 ++-
 cpp/velox/udf/UdfLoader.h                          |  45 ++--
 cpp/velox/udf/examples/MyUDF.cc                    | 257 ++++++++++++++++-----
 .../spark/sql/catalyst/types/DataTypeUtils.scala   |  29 +--
 .../spark/sql/catalyst/types/DataTypeUtils.scala   |  29 +--
 .../spark/sql/catalyst/types/DataTypeUtils.scala   |  29 +--
 11 files changed, 415 insertions(+), 190 deletions(-)

diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index 3445e40e5..bdfd24ed5 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -32,6 +32,7 @@ import 
org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
Expression, ExpressionInfo}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, 
ExprCode}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -72,6 +73,25 @@ case class UserDefinedAggregateFunction(
   }
 }
 
+trait UDFSignatureBase {
+  val expressionType: ExpressionType
+  val children: Seq[DataType]
+  val variableArity: Boolean
+}
+
+case class UDFSignature(
+    expressionType: ExpressionType,
+    children: Seq[DataType],
+    variableArity: Boolean)
+  extends UDFSignatureBase
+
+case class UDAFSignature(
+    expressionType: ExpressionType,
+    children: Seq[DataType],
+    variableArity: Boolean,
+    intermediateAttrs: Seq[AttributeReference])
+  extends UDFSignatureBase
+
 case class UDFExpression(
     name: String,
     dataType: DataType,
@@ -109,31 +129,40 @@ case class UDFExpression(
 object UDFResolver extends Logging {
   private val UDFNames = mutable.HashSet[String]()
   // (udf_name, arg1, arg2, ...) => return type
-  private val UDFMap = mutable.HashMap[(String, Seq[DataType]), 
ExpressionType]()
+  private val UDFMap = mutable.HashMap[String, 
mutable.MutableList[UDFSignature]]()
 
   private val UDAFNames = mutable.HashSet[String]()
   // (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
   private val UDAFMap =
-    mutable.HashMap[(String, Seq[DataType]), (ExpressionType, 
Seq[AttributeReference])]()
+    mutable.HashMap[String, mutable.MutableList[UDAFSignature]]()
 
   private val LIB_EXTENSION = ".so"
 
   // Called by JNI.
-  def registerUDF(name: String, returnType: Array[Byte], argTypes: 
Array[Byte]): Unit = {
+  def registerUDF(
+      name: String,
+      returnType: Array[Byte],
+      argTypes: Array[Byte],
+      variableArity: Boolean): Unit = {
     registerUDF(
       name,
       ConverterUtils.parseFromBytes(returnType),
-      ConverterUtils.parseFromBytes(argTypes))
+      ConverterUtils.parseFromBytes(argTypes),
+      variableArity)
   }
 
   private def registerUDF(
       name: String,
       returnType: ExpressionType,
-      argTypes: ExpressionType): Unit = {
+      argTypes: ExpressionType,
+      variableArity: Boolean): Unit = {
     assert(argTypes.dataType.isInstanceOf[StructType])
-    UDFMap.put(
-      (name, 
argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)),
-      returnType)
+    val v =
+      UDFMap.getOrElseUpdate(name, mutable.MutableList[UDFSignature]())
+    v += UDFSignature(
+      returnType,
+      argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
+      variableArity)
     UDFNames += name
     logInfo(s"Registered UDF: $name($argTypes) -> $returnType")
   }
@@ -142,12 +171,14 @@ object UDFResolver extends Logging {
       name: String,
       returnType: Array[Byte],
       argTypes: Array[Byte],
-      intermediateTypes: Array[Byte]): Unit = {
+      intermediateTypes: Array[Byte],
+      variableArity: Boolean): Unit = {
     registerUDAF(
       name,
       ConverterUtils.parseFromBytes(returnType),
       ConverterUtils.parseFromBytes(argTypes),
-      ConverterUtils.parseFromBytes(intermediateTypes)
+      ConverterUtils.parseFromBytes(intermediateTypes),
+      variableArity
     )
   }
 
@@ -155,7 +186,8 @@ object UDFResolver extends Logging {
       name: String,
       returnType: ExpressionType,
       argTypes: ExpressionType,
-      intermediateTypes: ExpressionType): Unit = {
+      intermediateTypes: ExpressionType,
+      variableArity: Boolean): Unit = {
     assert(argTypes.dataType.isInstanceOf[StructType])
     assert(intermediateTypes.dataType.isInstanceOf[StructType])
 
@@ -164,10 +196,14 @@ object UDFResolver extends Logging {
         case (f, index) =>
           AttributeReference(s"inter_$index", f.dataType, f.nullable)()
       }
-    UDAFMap.put(
-      (name, 
argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)),
-      (returnType, aggBufferAttributes)
-    )
+
+    val v =
+      UDAFMap.getOrElseUpdate(name, mutable.MutableList[UDAFSignature]())
+    v += UDAFSignature(
+      returnType,
+      argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
+      variableArity,
+      aggBufferAttributes)
     UDAFNames += name
     logInfo(s"Registered UDAF: $name($argTypes) -> $returnType")
   }
@@ -319,30 +355,81 @@ object UDFResolver extends Logging {
   }
 
   private def getUdfExpression(name: String)(children: Seq[Expression]) = {
-    val expressionType =
-      UDFMap.getOrElse(
-        (name, children.map(_.dataType)),
-        throw new UnsupportedOperationException(
-          s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", 
")} " +
-            s"is not registered.")
-      )
-    UDFExpression(name, expressionType.dataType, expressionType.nullable, 
children)
+    def errorMessage: String =
+      s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} 
is not registered."
+
+    val signatures =
+      UDFMap.getOrElse(name, throw new 
UnsupportedOperationException(errorMessage));
+
+    signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+      case Some(sig) =>
+        UDFExpression(name, sig.expressionType.dataType, 
sig.expressionType.nullable, children)
+      case None =>
+        throw new UnsupportedOperationException(errorMessage)
+    }
   }
 
   private def getUdafExpression(name: String)(children: Seq[Expression]) = {
-    val (expressionType, aggBufferAttributes) =
+    def errorMessage: String =
+      s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} 
is not registered."
+
+    val signatures =
       UDAFMap.getOrElse(
-        (name, children.map(_.dataType)),
-        throw new UnsupportedOperationException(
-          s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", 
")} " +
-            s"is not registered.")
+        name,
+        throw new UnsupportedOperationException(errorMessage)
       )
 
-    UserDefinedAggregateFunction(
-      name,
-      expressionType.dataType,
-      expressionType.nullable,
-      children,
-      aggBufferAttributes)
+    signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+      case Some(sig) =>
+        UserDefinedAggregateFunction(
+          name,
+          sig.expressionType.dataType,
+          sig.expressionType.nullable,
+          children,
+          sig.intermediateAttrs)
+      case None =>
+        throw new UnsupportedOperationException(errorMessage)
+    }
+  }
+
+  // Returns true if required data types match the function signature.
+  // If the function signature is variable arity, the number of the last 
argument can be zero
+  // or more.
+  private def tryBind(sig: UDFSignatureBase, requiredDataTypes: 
Seq[DataType]): Boolean = {
+    if (!sig.variableArity) {
+      sig.children.size == requiredDataTypes.size &&
+      sig.children
+        .zip(requiredDataTypes)
+        .forall { case (candidate, required) => 
DataTypeUtils.sameType(candidate, required) }
+    } else {
+      // If variableArity is true, there must be at least one argument in the 
signature.
+      if (requiredDataTypes.size < sig.children.size - 1) {
+        false
+      } else if (requiredDataTypes.size == sig.children.size - 1) {
+        sig.children
+          .dropRight(1)
+          .zip(requiredDataTypes)
+          .forall { case (candidate, required) => 
DataTypeUtils.sameType(candidate, required) }
+      } else {
+        val varArgStartIndex = sig.children.size - 1
+        // First check all var args has the same type with the last argument 
of the signature.
+        if (
+          !requiredDataTypes
+            .drop(varArgStartIndex)
+            .forall(argType => DataTypeUtils.sameType(sig.children.last, 
argType))
+        ) {
+          false
+        } else if (varArgStartIndex == 0) {
+          // No fixed args.
+          true
+        } else {
+          // Whether fixed args matches.
+          sig.children
+            .dropRight(1)
+            .zip(requiredDataTypes.dropRight(1 + requiredDataTypes.size - 
sig.children.size))
+            .forall { case (candidate, required) => 
DataTypeUtils.sameType(candidate, required) }
+        }
+      }
+    }
   }
 }
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 40452dfbc..4d2f9fae3 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -71,20 +71,24 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with 
SQLHelper {
       .set("spark.memory.offHeap.size", "1024MB")
   }
 
-  testWithSpecifiedSparkVersion("test udf", Some("3.2")) {
+  test("test udf") {
     val df = spark.sql("""select
-                         |  myudf1(1),
-                         |  myudf1(1L),
-                         |  myudf2(100L),
+                         |  myudf1(100L),
+                         |  myudf2(1),
+                         |  myudf2(1L),
+                         |  myudf3(),
+                         |  myudf3(1),
+                         |  myudf3(1, 2, 3),
+                         |  myudf3(1L),
+                         |  myudf3(1L, 2L, 3L),
                          |  mydate(cast('2024-03-25' as date), 5)
                          |""".stripMargin)
-    df.collect()
     assert(
       df.collect()
-        .sameElements(Array(Row(6, 6L, 105, Date.valueOf("2024-03-30")))))
+        .sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, 
Date.valueOf("2024-03-30")))))
   }
 
-  testWithSpecifiedSparkVersion("test udaf", Some("3.2")) {
+  test("test udaf") {
     val df = spark.sql("""select
                          |  myavg(1),
                          |  myavg(1L),
diff --git a/cpp/velox/jni/JniUdf.cc b/cpp/velox/jni/JniUdf.cc
index cd5a4f7c8..cab90b325 100644
--- a/cpp/velox/jni/JniUdf.cc
+++ b/cpp/velox/jni/JniUdf.cc
@@ -41,8 +41,8 @@ void gluten::initVeloxJniUDF(JNIEnv* env) {
   udfResolverClass = createGlobalClassReferenceOrError(env, 
kUdfResolverClassPath.c_str());
 
   // methods
-  registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", 
"(Ljava/lang/String;[B[B)V");
-  registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, 
"registerUDAF", "(Ljava/lang/String;[B[B[B)V");
+  registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", 
"(Ljava/lang/String;[B[BZ)V");
+  registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, 
"registerUDAF", "(Ljava/lang/String;[B[B[BZ)V");
 }
 
 void gluten::finalizeVeloxJniUDF(JNIEnv* env) {
@@ -70,9 +70,10 @@ void gluten::jniGetFunctionSignatures(JNIEnv* env) {
           0,
           signature->intermediateType.length(),
           reinterpret_cast<const jbyte*>(signature->intermediateType.c_str()));
-      env->CallVoidMethod(instance, registerUDAFMethod, name, returnType, 
argTypes, intermediateType);
+      env->CallVoidMethod(
+          instance, registerUDAFMethod, name, returnType, argTypes, 
intermediateType, signature->variableArity);
     } else {
-      env->CallVoidMethod(instance, registerUDFMethod, name, returnType, 
argTypes);
+      env->CallVoidMethod(instance, registerUDFMethod, name, returnType, 
argTypes, signature->variableArity);
     }
     checkException(env);
   }
diff --git a/cpp/velox/udf/Udaf.h b/cpp/velox/udf/Udaf.h
index 7e8f03402..5b33e0611 100644
--- a/cpp/velox/udf/Udaf.h
+++ b/cpp/velox/udf/Udaf.h
@@ -27,6 +27,7 @@ struct UdafEntry {
   const char** argTypes;
 
   const char* intermediateType{nullptr};
+  bool variableArity{false};
 };
 
 #define GLUTEN_GET_NUM_UDAF getNumUdaf
diff --git a/cpp/velox/udf/Udf.h b/cpp/velox/udf/Udf.h
index c3e579c44..1fa3c54d5 100644
--- a/cpp/velox/udf/Udf.h
+++ b/cpp/velox/udf/Udf.h
@@ -25,6 +25,8 @@ struct UdfEntry {
 
   size_t numArgs;
   const char** argTypes;
+
+  bool variableArity{false};
 };
 
 #define GLUTEN_GET_NUM_UDF getNumUdf
diff --git a/cpp/velox/udf/UdfLoader.cc b/cpp/velox/udf/UdfLoader.cc
index a8a99ce9f..02aa410a9 100644
--- a/cpp/velox/udf/UdfLoader.cc
+++ b/cpp/velox/udf/UdfLoader.cc
@@ -86,11 +86,11 @@ 
std::unordered_set<std::shared_ptr<UdfLoader::UdfSignature>> UdfLoader::getRegis
         const auto& entry = udfEntries[i];
         auto dataType = toSubstraitTypeStr(entry.dataType);
         auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes);
-        signatures_.insert(std::make_shared<UdfSignature>(entry.name, 
dataType, argTypes));
+        signatures_.insert(std::make_shared<UdfSignature>(entry.name, 
dataType, argTypes, entry.variableArity));
       }
       free(udfEntries);
     } else {
-      LOG(INFO) << "No UDFs found in " << libPath;
+      LOG(INFO) << "No UDF found in " << libPath;
     }
 
     // Handle UDAFs.
@@ -110,11 +110,12 @@ 
std::unordered_set<std::shared_ptr<UdfLoader::UdfSignature>> UdfLoader::getRegis
         auto dataType = toSubstraitTypeStr(entry.dataType);
         auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes);
         auto intermediateType = toSubstraitTypeStr(entry.intermediateType);
-        signatures_.insert(std::make_shared<UdfSignature>(entry.name, 
dataType, argTypes, intermediateType));
+        signatures_.insert(
+            std::make_shared<UdfSignature>(entry.name, dataType, argTypes, 
intermediateType, entry.variableArity));
       }
       free(udafEntries);
     } else {
-      LOG(INFO) << "No UDAFs found in " << libPath;
+      LOG(INFO) << "No UDAF found in " << libPath;
     }
   }
   return signatures_;
@@ -151,4 +152,26 @@ std::shared_ptr<UdfLoader> UdfLoader::getInstance() {
   return instance;
 }
 
+std::string UdfLoader::toSubstraitTypeStr(const std::string& type) {
+  auto returnType = parser_.parse(type);
+  auto substraitType = convertor_.toSubstraitType(arena_, returnType);
+
+  std::string output;
+  substraitType.SerializeToString(&output);
+  return output;
+}
+
+std::string UdfLoader::toSubstraitTypeStr(int32_t numArgs, const char** args) {
+  std::vector<facebook::velox::TypePtr> argTypes;
+  argTypes.resize(numArgs);
+  for (auto i = 0; i < numArgs; ++i) {
+    argTypes[i] = parser_.parse(args[i]);
+  }
+  auto substraitType = convertor_.toSubstraitType(arena_, 
facebook::velox::ROW(std::move(argTypes)));
+
+  std::string output;
+  substraitType.SerializeToString(&output);
+  return output;
+}
+
 } // namespace gluten
diff --git a/cpp/velox/udf/UdfLoader.h b/cpp/velox/udf/UdfLoader.h
index 31098d2f4..2783beb85 100644
--- a/cpp/velox/udf/UdfLoader.h
+++ b/cpp/velox/udf/UdfLoader.h
@@ -36,11 +36,22 @@ class UdfLoader {
 
     std::string intermediateType{};
 
-    UdfSignature(std::string name, std::string returnType, std::string 
argTypes)
-        : name(name), returnType(returnType), argTypes(argTypes) {}
-
-    UdfSignature(std::string name, std::string returnType, std::string 
argTypes, std::string intermediateType)
-        : name(name), returnType(returnType), argTypes(argTypes), 
intermediateType(intermediateType) {}
+    bool variableArity;
+
+    UdfSignature(std::string name, std::string returnType, std::string 
argTypes, bool variableArity)
+        : name(name), returnType(returnType), argTypes(argTypes), 
variableArity(variableArity) {}
+
+    UdfSignature(
+        std::string name,
+        std::string returnType,
+        std::string argTypes,
+        std::string intermediateType,
+        bool variableArity)
+        : name(name),
+          returnType(returnType),
+          argTypes(argTypes),
+          intermediateType(intermediateType),
+          variableArity(variableArity) {}
 
     ~UdfSignature() = default;
   };
@@ -58,27 +69,9 @@ class UdfLoader {
  private:
   void loadUdfLibraries0(const std::vector<std::string>& libPaths);
 
-  std::string toSubstraitTypeStr(const std::string& type) {
-    auto returnType = parser_.parse(type);
-    auto substraitType = convertor_.toSubstraitType(arena_, returnType);
-
-    std::string output;
-    substraitType.SerializeToString(&output);
-    return output;
-  }
-
-  std::string toSubstraitTypeStr(int32_t numArgs, const char** args) {
-    std::vector<facebook::velox::TypePtr> argTypes;
-    argTypes.resize(numArgs);
-    for (auto i = 0; i < numArgs; ++i) {
-      argTypes[i] = parser_.parse(args[i]);
-    }
-    auto substraitType = convertor_.toSubstraitType(arena_, 
facebook::velox::ROW(std::move(argTypes)));
-
-    std::string output;
-    substraitType.SerializeToString(&output);
-    return output;
-  }
+  std::string toSubstraitTypeStr(const std::string& type);
+
+  std::string toSubstraitTypeStr(int32_t numArgs, const char** args);
 
   std::unordered_map<std::string, void*> handles_;
 
diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc
index 578e3effb..88bc3ad85 100644
--- a/cpp/velox/udf/examples/MyUDF.cc
+++ b/cpp/velox/udf/examples/MyUDF.cc
@@ -21,8 +21,6 @@
 #include <iostream>
 #include "udf/Udf.h"
 
-namespace {
-
 using namespace facebook::velox;
 using namespace facebook::velox::exec;
 
@@ -30,10 +28,26 @@ static const char* kInteger = "int";
 static const char* kBigInt = "bigint";
 static const char* kDate = "date";
 
+class UdfRegisterer {
+ public:
+  ~UdfRegisterer() = default;
+
+  // Returns the number of UDFs in populateUdfEntries.
+  virtual int getNumUdf() = 0;
+
+  // Populate the udfEntries, starting at the given index.
+  virtual void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) = 
0;
+
+  // Register all function signatures to velox.
+  virtual void registerSignatures() = 0;
+};
+
+namespace myudf {
+
 template <TypeKind Kind>
-class PlusConstantFunction : public exec::VectorFunction {
+class PlusFiveFunction : public exec::VectorFunction {
  public:
-  explicit PlusConstantFunction(int32_t addition) : addition_(addition) {}
+  explicit PlusFiveFunction() {}
 
   void apply(
       const SelectivityVector& rows,
@@ -42,12 +56,6 @@ class PlusConstantFunction : public exec::VectorFunction {
       exec::EvalCtx& context,
       VectorPtr& result) const override {
     using nativeType = typename TypeTraits<Kind>::NativeType;
-    VELOX_CHECK_EQ(args.size(), 1);
-
-    auto& arg = args[0];
-
-    // The argument may be flat or constant.
-    VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
 
     BaseVector::ensureWritable(rows, createScalarType<Kind>(), context.pool(), 
result);
 
@@ -56,79 +64,218 @@ class PlusConstantFunction : public exec::VectorFunction {
 
     flatResult->clearNulls(rows);
 
-    if (arg->isConstantEncoding()) {
-      auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
-      rows.applyToSelected([&](auto row) { rawResult[row] = value + addition_; 
});
-    } else {
-      auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+    rows.applyToSelected([&](auto row) { rawResult[row] = 5; });
 
-      rows.applyToSelected([&](auto row) { rawResult[row] = rawInput[row] + 
addition_; });
+    if (args.size() == 0) {
+      return;
     }
-  }
 
- private:
-  const int32_t addition_;
-};
-
-template <typename T>
-struct MyDateSimpleFunction {
-  VELOX_DEFINE_FUNCTION_TYPES(T);
-
-  FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type<Date>& date, 
const arg_type<int32_t> addition) {
-    result = date + addition;
+    for (int i = 0; i < args.size(); i++) {
+      auto& arg = args[i];
+      VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
+      if (arg->isConstantEncoding()) {
+        auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
+        rows.applyToSelected([&](auto row) { rawResult[row] += value; });
+      } else {
+        auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+        rows.applyToSelected([&](auto row) { rawResult[row] += rawInput[row]; 
});
+      }
+    }
   }
 };
 
-std::shared_ptr<facebook::velox::exec::VectorFunction> makeMyUdf1(
+static std::shared_ptr<facebook::velox::exec::VectorFunction> makePlusConstant(
     const std::string& /*name*/,
     const std::vector<exec::VectorFunctionArg>& inputArgs,
     const core::QueryConfig& /*config*/) {
+  if (inputArgs.size() == 0) {
+    return std::make_shared<PlusFiveFunction<TypeKind::INTEGER>>();
+  }
   auto typeKind = inputArgs[0].type->kind();
   switch (typeKind) {
     case TypeKind::INTEGER:
-      return std::make_shared<PlusConstantFunction<TypeKind::INTEGER>>(5);
+      return std::make_shared<PlusFiveFunction<TypeKind::INTEGER>>();
     case TypeKind::BIGINT:
-      return std::make_shared<PlusConstantFunction<TypeKind::BIGINT>>(5);
+      return std::make_shared<PlusFiveFunction<TypeKind::BIGINT>>();
     default:
       VELOX_UNREACHABLE();
   }
 }
 
-static std::vector<std::shared_ptr<exec::FunctionSignature>> 
integerSignatures() {
-  // integer -> integer, bigint ->bigint
-  return {
-      
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(),
-      
exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
-}
+// name: myudf1
+// signatures:
+//    bigint -> bigint
+// type: VectorFunction
+class MyUdf1Registerer final : public UdfRegisterer {
+ public:
+  int getNumUdf() override {
+    return 1;
+  }
 
-static std::vector<std::shared_ptr<exec::FunctionSignature>> 
bigintSignatures() {
-  // bigint -> bigint
-  return 
{exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
-}
+  void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+    udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_};
+  }
+
+  void registerSignatures() override {
+    facebook::velox::exec::registerVectorFunction(
+        name_, bigintSignatures(), 
std::make_unique<PlusFiveFunction<facebook::velox::TypeKind::BIGINT>>());
+  }
+
+ private:
+  std::vector<std::shared_ptr<exec::FunctionSignature>> bigintSignatures() {
+    return 
{exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+  }
+
+  const std::string name_ = "myudf1";
+  const char* bigintArg_[1] = {kBigInt};
+};
+
+// name: myudf2
+// signatures:
+//    integer -> integer
+//    bigint -> bigint
+// type: StatefulVectorFunction
+class MyUdf2Registerer final : public UdfRegisterer {
+ public:
+  int getNumUdf() override {
+    return 2;
+  }
+
+  void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+    udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_};
+    udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_};
+  }
+
+  void registerSignatures() override {
+    facebook::velox::exec::registerStatefulVectorFunction(name_, 
integerAndBigintSignatures(), makePlusConstant);
+  }
+
+ private:
+  std::vector<std::shared_ptr<exec::FunctionSignature>> 
integerAndBigintSignatures() {
+    return {
+        
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(),
+        
exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+  }
+
+  const std::string name_ = "myudf2";
+  const char* integerArg_[1] = {kInteger};
+  const char* bigintArg_[1] = {kBigInt};
+};
+
+// name: myudf3
+// signatures:
+//    [integer,] ... -> integer
+//    bigint, [bigint,] ... -> bigint
+// type: StatefulVectorFunction with variable arity
+class MyUdf3Registerer final : public UdfRegisterer {
+ public:
+  int getNumUdf() override {
+    return 2;
+  }
+
+  void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+    udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_, true};
+    udfEntries[index++] = {name_.c_str(), kBigInt, 2, bigintArgs_, true};
+  }
+
+  void registerSignatures() override {
+    facebook::velox::exec::registerStatefulVectorFunction(
+        name_, integerAndBigintSignaturesWithVariableArity(), 
makePlusConstant);
+  }
+
+ private:
+  std::vector<std::shared_ptr<exec::FunctionSignature>> 
integerAndBigintSignaturesWithVariableArity() {
+    return {
+        
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").variableArity().build(),
+        exec::FunctionSignatureBuilder()
+            .returnType("bigint")
+            .argumentType("bigint")
+            .argumentType("bigint")
+            .variableArity()
+            .build()};
+  }
 
-} // namespace
+  const std::string name_ = "myudf3";
+  const char* integerArg_[1] = {kInteger};
+  const char* bigintArgs_[2] = {kBigInt, kBigInt};
+};
+} // namespace myudf
 
-const int kNumMyUdf = 4;
+namespace mydate {
+template <typename T>
+struct MyDateSimpleFunction {
+  VELOX_DEFINE_FUNCTION_TYPES(T);
+
+  FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type<Date>& date, 
const arg_type<int32_t> addition) {
+    result = date + addition;
+  }
+};
+
+// name: mydate
+// signatures:
+//    date, integer -> bigint
+// type: SimpleFunction
+class MyDateRegisterer final : public UdfRegisterer {
+ public:
+  int getNumUdf() override {
+    return 1;
+  }
+
+  void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+    udfEntries[index++] = {name_.c_str(), kDate, 2, myDateArg_};
+  }
+
+  void registerSignatures() override {
+    facebook::velox::registerFunction<mydate::MyDateSimpleFunction, Date, 
Date, int32_t>({name_});
+  }
+
+ private:
+  const std::string name_ = "mydate";
+  const char* myDateArg_[2] = {kDate, kInteger};
+};
+} // namespace mydate
+
+std::vector<std::shared_ptr<UdfRegisterer>>& globalRegisters() {
+  static std::vector<std::shared_ptr<UdfRegisterer>> registerers;
+  return registerers;
+}
+
+void setupRegisterers() {
+  static bool inited = false;
+  if (inited) {
+    return;
+  }
+  auto& registerers = globalRegisters();
+  registerers.push_back(std::make_shared<myudf::MyUdf1Registerer>());
+  registerers.push_back(std::make_shared<myudf::MyUdf2Registerer>());
+  registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
+  registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
+  inited = true;
+}
 
 DEFINE_GET_NUM_UDF {
-  return kNumMyUdf;
+  setupRegisterers();
+
+  int numUdf = 0;
+  for (const auto& registerer : globalRegisters()) {
+    numUdf += registerer->getNumUdf();
+  }
+  return numUdf;
 }
 
-const char* myUdf1Arg1[] = {kInteger};
-const char* myUdf1Arg2[] = {kBigInt};
-const char* myUdf2Arg1[] = {kBigInt};
-const char* myDateArg[] = {kDate, kInteger};
 DEFINE_GET_UDF_ENTRIES {
+  setupRegisterers();
+
   int index = 0;
-  udfEntries[index++] = {"myudf1", kInteger, 1, myUdf1Arg1};
-  udfEntries[index++] = {"myudf1", kBigInt, 1, myUdf1Arg2};
-  udfEntries[index++] = {"myudf2", kBigInt, 1, myUdf2Arg1};
-  udfEntries[index++] = {"mydate", kDate, 2, myDateArg};
+  for (const auto& registerer : globalRegisters()) {
+    registerer->populateUdfEntries(index, udfEntries);
+  }
 }
 
 DEFINE_REGISTER_UDF {
-  facebook::velox::exec::registerStatefulVectorFunction("myudf1", 
integerSignatures(), makeMyUdf1);
-  facebook::velox::exec::registerVectorFunction(
-      "myudf2", bigintSignatures(), 
std::make_unique<PlusConstantFunction<facebook::velox::TypeKind::BIGINT>>(5));
-  facebook::velox::registerFunction<MyDateSimpleFunction, Date, Date, 
int32_t>({"mydate"});
+  setupRegisterers();
+
+  for (const auto& registerer : globalRegisters()) {
+    registerer->registerSignatures();
+  }
 }
diff --git a/cpp/velox/udf/Udf.h 
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
similarity index 60%
copy from cpp/velox/udf/Udf.h
copy to 
shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c3e579c44..597b5936f 100644
--- a/cpp/velox/udf/Udf.h
+++ 
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -14,26 +14,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.sql.catalyst.types
 
-#pragma once
+import org.apache.spark.sql.types.DataType
 
-namespace gluten {
+object DataTypeUtils {
 
-struct UdfEntry {
-  const char* name;
-  const char* dataType;
-
-  size_t numArgs;
-  const char** argTypes;
-};
-
-#define GLUTEN_GET_NUM_UDF getNumUdf
-#define DEFINE_GET_NUM_UDF extern "C" int GLUTEN_GET_NUM_UDF()
-
-#define GLUTEN_GET_UDF_ENTRIES getUdfEntries
-#define DEFINE_GET_UDF_ENTRIES extern "C" void 
GLUTEN_GET_UDF_ENTRIES(gluten::UdfEntry* udfEntries)
-
-#define GLUTEN_REGISTER_UDF registerUdf
-#define DEFINE_REGISTER_UDF extern "C" void GLUTEN_REGISTER_UDF()
-
-} // namespace gluten
+  /**
+   * Check if `this` and `other` are the same data type when ignoring 
nullability
+   * (`StructField.nullable`, `ArrayType.containsNull`, and 
`MapType.valueContainsNull`).
+   */
+  def sameType(left: DataType, right: DataType): Boolean = left.sameType(right)
+}
diff --git a/cpp/velox/udf/Udf.h 
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
similarity index 60%
copy from cpp/velox/udf/Udf.h
copy to 
shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c3e579c44..597b5936f 100644
--- a/cpp/velox/udf/Udf.h
+++ 
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -14,26 +14,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.sql.catalyst.types
 
-#pragma once
+import org.apache.spark.sql.types.DataType
 
-namespace gluten {
+object DataTypeUtils {
 
-struct UdfEntry {
-  const char* name;
-  const char* dataType;
-
-  size_t numArgs;
-  const char** argTypes;
-};
-
-#define GLUTEN_GET_NUM_UDF getNumUdf
-#define DEFINE_GET_NUM_UDF extern "C" int GLUTEN_GET_NUM_UDF()
-
-#define GLUTEN_GET_UDF_ENTRIES getUdfEntries
-#define DEFINE_GET_UDF_ENTRIES extern "C" void 
GLUTEN_GET_UDF_ENTRIES(gluten::UdfEntry* udfEntries)
-
-#define GLUTEN_REGISTER_UDF registerUdf
-#define DEFINE_REGISTER_UDF extern "C" void GLUTEN_REGISTER_UDF()
-
-} // namespace gluten
+  /**
+   * Check if `this` and `other` are the same data type when ignoring 
nullability
+   * (`StructField.nullable`, `ArrayType.containsNull`, and 
`MapType.valueContainsNull`).
+   */
+  def sameType(left: DataType, right: DataType): Boolean = left.sameType(right)
+}
diff --git a/cpp/velox/udf/Udf.h 
b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
similarity index 60%
copy from cpp/velox/udf/Udf.h
copy to 
shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c3e579c44..597b5936f 100644
--- a/cpp/velox/udf/Udf.h
+++ 
b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -14,26 +14,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+package org.apache.spark.sql.catalyst.types
 
-#pragma once
+import org.apache.spark.sql.types.DataType
 
-namespace gluten {
+object DataTypeUtils {
 
-struct UdfEntry {
-  const char* name;
-  const char* dataType;
-
-  size_t numArgs;
-  const char** argTypes;
-};
-
-#define GLUTEN_GET_NUM_UDF getNumUdf
-#define DEFINE_GET_NUM_UDF extern "C" int GLUTEN_GET_NUM_UDF()
-
-#define GLUTEN_GET_UDF_ENTRIES getUdfEntries
-#define DEFINE_GET_UDF_ENTRIES extern "C" void 
GLUTEN_GET_UDF_ENTRIES(gluten::UdfEntry* udfEntries)
-
-#define GLUTEN_REGISTER_UDF registerUdf
-#define DEFINE_REGISTER_UDF extern "C" void GLUTEN_REGISTER_UDF()
-
-} // namespace gluten
+  /**
+   * Check if `this` and `other` are the same data type when ignoring 
nullability
+   * (`StructField.nullable`, `ArrayType.containsNull`, and 
`MapType.valueContainsNull`).
+   */
+  def sameType(left: DataType, right: DataType): Boolean = left.sameType(right)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to