This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 1e07bde47 [VL] UDF: Support variable arity in function signatures
(#5495)
1e07bde47 is described below
commit 1e07bde4710961ae8e3457c7a07c51d03b08cf09
Author: Rong Ma <[email protected]>
AuthorDate: Fri Apr 26 08:52:33 2024 +0800
[VL] UDF: Support variable arity in function signatures (#5495)
---
.../apache/spark/sql/expression/UDFResolver.scala | 155 ++++++++++---
.../apache/gluten/expression/VeloxUdfSuite.scala | 18 +-
cpp/velox/jni/JniUdf.cc | 9 +-
cpp/velox/udf/Udaf.h | 1 +
cpp/velox/udf/Udf.h | 2 +
cpp/velox/udf/UdfLoader.cc | 31 ++-
cpp/velox/udf/UdfLoader.h | 45 ++--
cpp/velox/udf/examples/MyUDF.cc | 257 ++++++++++++++++-----
.../spark/sql/catalyst/types/DataTypeUtils.scala | 29 +--
.../spark/sql/catalyst/types/DataTypeUtils.scala | 29 +--
.../spark/sql/catalyst/types/DataTypeUtils.scala | 29 +--
11 files changed, 415 insertions(+), 190 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index 3445e40e5..bdfd24ed5 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -32,6 +32,7 @@ import
org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -72,6 +73,25 @@ case class UserDefinedAggregateFunction(
}
}
+trait UDFSignatureBase {
+ val expressionType: ExpressionType
+ val children: Seq[DataType]
+ val variableArity: Boolean
+}
+
+case class UDFSignature(
+ expressionType: ExpressionType,
+ children: Seq[DataType],
+ variableArity: Boolean)
+ extends UDFSignatureBase
+
+case class UDAFSignature(
+ expressionType: ExpressionType,
+ children: Seq[DataType],
+ variableArity: Boolean,
+ intermediateAttrs: Seq[AttributeReference])
+ extends UDFSignatureBase
+
case class UDFExpression(
name: String,
dataType: DataType,
@@ -109,31 +129,40 @@ case class UDFExpression(
object UDFResolver extends Logging {
private val UDFNames = mutable.HashSet[String]()
// (udf_name, arg1, arg2, ...) => return type
- private val UDFMap = mutable.HashMap[(String, Seq[DataType]),
ExpressionType]()
+ private val UDFMap = mutable.HashMap[String,
mutable.MutableList[UDFSignature]]()
private val UDAFNames = mutable.HashSet[String]()
// (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
private val UDAFMap =
- mutable.HashMap[(String, Seq[DataType]), (ExpressionType,
Seq[AttributeReference])]()
+ mutable.HashMap[String, mutable.MutableList[UDAFSignature]]()
private val LIB_EXTENSION = ".so"
// Called by JNI.
- def registerUDF(name: String, returnType: Array[Byte], argTypes:
Array[Byte]): Unit = {
+ def registerUDF(
+ name: String,
+ returnType: Array[Byte],
+ argTypes: Array[Byte],
+ variableArity: Boolean): Unit = {
registerUDF(
name,
ConverterUtils.parseFromBytes(returnType),
- ConverterUtils.parseFromBytes(argTypes))
+ ConverterUtils.parseFromBytes(argTypes),
+ variableArity)
}
private def registerUDF(
name: String,
returnType: ExpressionType,
- argTypes: ExpressionType): Unit = {
+ argTypes: ExpressionType,
+ variableArity: Boolean): Unit = {
assert(argTypes.dataType.isInstanceOf[StructType])
- UDFMap.put(
- (name,
argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)),
- returnType)
+ val v =
+ UDFMap.getOrElseUpdate(name, mutable.MutableList[UDFSignature]())
+ v += UDFSignature(
+ returnType,
+ argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
+ variableArity)
UDFNames += name
logInfo(s"Registered UDF: $name($argTypes) -> $returnType")
}
@@ -142,12 +171,14 @@ object UDFResolver extends Logging {
name: String,
returnType: Array[Byte],
argTypes: Array[Byte],
- intermediateTypes: Array[Byte]): Unit = {
+ intermediateTypes: Array[Byte],
+ variableArity: Boolean): Unit = {
registerUDAF(
name,
ConverterUtils.parseFromBytes(returnType),
ConverterUtils.parseFromBytes(argTypes),
- ConverterUtils.parseFromBytes(intermediateTypes)
+ ConverterUtils.parseFromBytes(intermediateTypes),
+ variableArity
)
}
@@ -155,7 +186,8 @@ object UDFResolver extends Logging {
name: String,
returnType: ExpressionType,
argTypes: ExpressionType,
- intermediateTypes: ExpressionType): Unit = {
+ intermediateTypes: ExpressionType,
+ variableArity: Boolean): Unit = {
assert(argTypes.dataType.isInstanceOf[StructType])
assert(intermediateTypes.dataType.isInstanceOf[StructType])
@@ -164,10 +196,14 @@ object UDFResolver extends Logging {
case (f, index) =>
AttributeReference(s"inter_$index", f.dataType, f.nullable)()
}
- UDAFMap.put(
- (name,
argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)),
- (returnType, aggBufferAttributes)
- )
+
+ val v =
+ UDAFMap.getOrElseUpdate(name, mutable.MutableList[UDAFSignature]())
+ v += UDAFSignature(
+ returnType,
+ argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
+ variableArity,
+ aggBufferAttributes)
UDAFNames += name
logInfo(s"Registered UDAF: $name($argTypes) -> $returnType")
}
@@ -319,30 +355,81 @@ object UDFResolver extends Logging {
}
private def getUdfExpression(name: String)(children: Seq[Expression]) = {
- val expressionType =
- UDFMap.getOrElse(
- (name, children.map(_.dataType)),
- throw new UnsupportedOperationException(
- s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(",
")} " +
- s"is not registered.")
- )
- UDFExpression(name, expressionType.dataType, expressionType.nullable,
children)
+ def errorMessage: String =
+ s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
+
+ val signatures =
+ UDFMap.getOrElse(name, throw new
UnsupportedOperationException(errorMessage));
+
+ signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+ case Some(sig) =>
+ UDFExpression(name, sig.expressionType.dataType,
sig.expressionType.nullable, children)
+ case None =>
+ throw new UnsupportedOperationException(errorMessage)
+ }
}
private def getUdafExpression(name: String)(children: Seq[Expression]) = {
- val (expressionType, aggBufferAttributes) =
+ def errorMessage: String =
+ s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
+
+ val signatures =
UDAFMap.getOrElse(
- (name, children.map(_.dataType)),
- throw new UnsupportedOperationException(
- s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(",
")} " +
- s"is not registered.")
+ name,
+ throw new UnsupportedOperationException(errorMessage)
)
- UserDefinedAggregateFunction(
- name,
- expressionType.dataType,
- expressionType.nullable,
- children,
- aggBufferAttributes)
+ signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+ case Some(sig) =>
+ UserDefinedAggregateFunction(
+ name,
+ sig.expressionType.dataType,
+ sig.expressionType.nullable,
+ children,
+ sig.intermediateAttrs)
+ case None =>
+ throw new UnsupportedOperationException(errorMessage)
+ }
+ }
+
+ // Returns true if required data types match the function signature.
+ // If the function signature is variable arity, the number of the last
argument can be zero
+ // or more.
+ private def tryBind(sig: UDFSignatureBase, requiredDataTypes:
Seq[DataType]): Boolean = {
+ if (!sig.variableArity) {
+ sig.children.size == requiredDataTypes.size &&
+ sig.children
+ .zip(requiredDataTypes)
+ .forall { case (candidate, required) =>
DataTypeUtils.sameType(candidate, required) }
+ } else {
+ // If variableArity is true, there must be at least one argument in the
signature.
+ if (requiredDataTypes.size < sig.children.size - 1) {
+ false
+ } else if (requiredDataTypes.size == sig.children.size - 1) {
+ sig.children
+ .dropRight(1)
+ .zip(requiredDataTypes)
+ .forall { case (candidate, required) =>
DataTypeUtils.sameType(candidate, required) }
+ } else {
+ val varArgStartIndex = sig.children.size - 1
+ // First check that all var args have the same type as the last argument
of the signature.
+ if (
+ !requiredDataTypes
+ .drop(varArgStartIndex)
+ .forall(argType => DataTypeUtils.sameType(sig.children.last,
argType))
+ ) {
+ false
+ } else if (varArgStartIndex == 0) {
+ // No fixed args.
+ true
+ } else {
+ // Whether fixed args matches.
+ sig.children
+ .dropRight(1)
+ .zip(requiredDataTypes.dropRight(1 + requiredDataTypes.size -
sig.children.size))
+ .forall { case (candidate, required) =>
DataTypeUtils.sameType(candidate, required) }
+ }
+ }
+ }
}
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 40452dfbc..4d2f9fae3 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -71,20 +71,24 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
.set("spark.memory.offHeap.size", "1024MB")
}
- testWithSpecifiedSparkVersion("test udf", Some("3.2")) {
+ test("test udf") {
val df = spark.sql("""select
- | myudf1(1),
- | myudf1(1L),
- | myudf2(100L),
+ | myudf1(100L),
+ | myudf2(1),
+ | myudf2(1L),
+ | myudf3(),
+ | myudf3(1),
+ | myudf3(1, 2, 3),
+ | myudf3(1L),
+ | myudf3(1L, 2L, 3L),
| mydate(cast('2024-03-25' as date), 5)
|""".stripMargin)
- df.collect()
assert(
df.collect()
- .sameElements(Array(Row(6, 6L, 105, Date.valueOf("2024-03-30")))))
+ .sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L,
Date.valueOf("2024-03-30")))))
}
- testWithSpecifiedSparkVersion("test udaf", Some("3.2")) {
+ test("test udaf") {
val df = spark.sql("""select
| myavg(1),
| myavg(1L),
diff --git a/cpp/velox/jni/JniUdf.cc b/cpp/velox/jni/JniUdf.cc
index cd5a4f7c8..cab90b325 100644
--- a/cpp/velox/jni/JniUdf.cc
+++ b/cpp/velox/jni/JniUdf.cc
@@ -41,8 +41,8 @@ void gluten::initVeloxJniUDF(JNIEnv* env) {
udfResolverClass = createGlobalClassReferenceOrError(env,
kUdfResolverClassPath.c_str());
// methods
- registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF",
"(Ljava/lang/String;[B[B)V");
- registerUDAFMethod = getMethodIdOrError(env, udfResolverClass,
"registerUDAF", "(Ljava/lang/String;[B[B[B)V");
+ registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF",
"(Ljava/lang/String;[B[BZ)V");
+ registerUDAFMethod = getMethodIdOrError(env, udfResolverClass,
"registerUDAF", "(Ljava/lang/String;[B[B[BZ)V");
}
void gluten::finalizeVeloxJniUDF(JNIEnv* env) {
@@ -70,9 +70,10 @@ void gluten::jniGetFunctionSignatures(JNIEnv* env) {
0,
signature->intermediateType.length(),
reinterpret_cast<const jbyte*>(signature->intermediateType.c_str()));
- env->CallVoidMethod(instance, registerUDAFMethod, name, returnType,
argTypes, intermediateType);
+ env->CallVoidMethod(
+ instance, registerUDAFMethod, name, returnType, argTypes,
intermediateType, signature->variableArity);
} else {
- env->CallVoidMethod(instance, registerUDFMethod, name, returnType,
argTypes);
+ env->CallVoidMethod(instance, registerUDFMethod, name, returnType,
argTypes, signature->variableArity);
}
checkException(env);
}
diff --git a/cpp/velox/udf/Udaf.h b/cpp/velox/udf/Udaf.h
index 7e8f03402..5b33e0611 100644
--- a/cpp/velox/udf/Udaf.h
+++ b/cpp/velox/udf/Udaf.h
@@ -27,6 +27,7 @@ struct UdafEntry {
const char** argTypes;
const char* intermediateType{nullptr};
+ bool variableArity{false};
};
#define GLUTEN_GET_NUM_UDAF getNumUdaf
diff --git a/cpp/velox/udf/Udf.h b/cpp/velox/udf/Udf.h
index c3e579c44..1fa3c54d5 100644
--- a/cpp/velox/udf/Udf.h
+++ b/cpp/velox/udf/Udf.h
@@ -25,6 +25,8 @@ struct UdfEntry {
size_t numArgs;
const char** argTypes;
+
+ bool variableArity{false};
};
#define GLUTEN_GET_NUM_UDF getNumUdf
diff --git a/cpp/velox/udf/UdfLoader.cc b/cpp/velox/udf/UdfLoader.cc
index a8a99ce9f..02aa410a9 100644
--- a/cpp/velox/udf/UdfLoader.cc
+++ b/cpp/velox/udf/UdfLoader.cc
@@ -86,11 +86,11 @@
std::unordered_set<std::shared_ptr<UdfLoader::UdfSignature>> UdfLoader::getRegis
const auto& entry = udfEntries[i];
auto dataType = toSubstraitTypeStr(entry.dataType);
auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes);
- signatures_.insert(std::make_shared<UdfSignature>(entry.name,
dataType, argTypes));
+ signatures_.insert(std::make_shared<UdfSignature>(entry.name,
dataType, argTypes, entry.variableArity));
}
free(udfEntries);
} else {
- LOG(INFO) << "No UDFs found in " << libPath;
+ LOG(INFO) << "No UDF found in " << libPath;
}
// Handle UDAFs.
@@ -110,11 +110,12 @@
std::unordered_set<std::shared_ptr<UdfLoader::UdfSignature>> UdfLoader::getRegis
auto dataType = toSubstraitTypeStr(entry.dataType);
auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes);
auto intermediateType = toSubstraitTypeStr(entry.intermediateType);
- signatures_.insert(std::make_shared<UdfSignature>(entry.name,
dataType, argTypes, intermediateType));
+ signatures_.insert(
+ std::make_shared<UdfSignature>(entry.name, dataType, argTypes,
intermediateType, entry.variableArity));
}
free(udafEntries);
} else {
- LOG(INFO) << "No UDAFs found in " << libPath;
+ LOG(INFO) << "No UDAF found in " << libPath;
}
}
return signatures_;
@@ -151,4 +152,26 @@ std::shared_ptr<UdfLoader> UdfLoader::getInstance() {
return instance;
}
+std::string UdfLoader::toSubstraitTypeStr(const std::string& type) {
+ auto returnType = parser_.parse(type);
+ auto substraitType = convertor_.toSubstraitType(arena_, returnType);
+
+ std::string output;
+ substraitType.SerializeToString(&output);
+ return output;
+}
+
+std::string UdfLoader::toSubstraitTypeStr(int32_t numArgs, const char** args) {
+ std::vector<facebook::velox::TypePtr> argTypes;
+ argTypes.resize(numArgs);
+ for (auto i = 0; i < numArgs; ++i) {
+ argTypes[i] = parser_.parse(args[i]);
+ }
+ auto substraitType = convertor_.toSubstraitType(arena_,
facebook::velox::ROW(std::move(argTypes)));
+
+ std::string output;
+ substraitType.SerializeToString(&output);
+ return output;
+}
+
} // namespace gluten
diff --git a/cpp/velox/udf/UdfLoader.h b/cpp/velox/udf/UdfLoader.h
index 31098d2f4..2783beb85 100644
--- a/cpp/velox/udf/UdfLoader.h
+++ b/cpp/velox/udf/UdfLoader.h
@@ -36,11 +36,22 @@ class UdfLoader {
std::string intermediateType{};
- UdfSignature(std::string name, std::string returnType, std::string
argTypes)
- : name(name), returnType(returnType), argTypes(argTypes) {}
-
- UdfSignature(std::string name, std::string returnType, std::string
argTypes, std::string intermediateType)
- : name(name), returnType(returnType), argTypes(argTypes),
intermediateType(intermediateType) {}
+ bool variableArity;
+
+ UdfSignature(std::string name, std::string returnType, std::string
argTypes, bool variableArity)
+ : name(name), returnType(returnType), argTypes(argTypes),
variableArity(variableArity) {}
+
+ UdfSignature(
+ std::string name,
+ std::string returnType,
+ std::string argTypes,
+ std::string intermediateType,
+ bool variableArity)
+ : name(name),
+ returnType(returnType),
+ argTypes(argTypes),
+ intermediateType(intermediateType),
+ variableArity(variableArity) {}
~UdfSignature() = default;
};
@@ -58,27 +69,9 @@ class UdfLoader {
private:
void loadUdfLibraries0(const std::vector<std::string>& libPaths);
- std::string toSubstraitTypeStr(const std::string& type) {
- auto returnType = parser_.parse(type);
- auto substraitType = convertor_.toSubstraitType(arena_, returnType);
-
- std::string output;
- substraitType.SerializeToString(&output);
- return output;
- }
-
- std::string toSubstraitTypeStr(int32_t numArgs, const char** args) {
- std::vector<facebook::velox::TypePtr> argTypes;
- argTypes.resize(numArgs);
- for (auto i = 0; i < numArgs; ++i) {
- argTypes[i] = parser_.parse(args[i]);
- }
- auto substraitType = convertor_.toSubstraitType(arena_,
facebook::velox::ROW(std::move(argTypes)));
-
- std::string output;
- substraitType.SerializeToString(&output);
- return output;
- }
+ std::string toSubstraitTypeStr(const std::string& type);
+
+ std::string toSubstraitTypeStr(int32_t numArgs, const char** args);
std::unordered_map<std::string, void*> handles_;
diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc
index 578e3effb..88bc3ad85 100644
--- a/cpp/velox/udf/examples/MyUDF.cc
+++ b/cpp/velox/udf/examples/MyUDF.cc
@@ -21,8 +21,6 @@
#include <iostream>
#include "udf/Udf.h"
-namespace {
-
using namespace facebook::velox;
using namespace facebook::velox::exec;
@@ -30,10 +28,26 @@ static const char* kInteger = "int";
static const char* kBigInt = "bigint";
static const char* kDate = "date";
+class UdfRegisterer {
+ public:
+ ~UdfRegisterer() = default;
+
+ // Returns the number of UDFs in populateUdfEntries.
+ virtual int getNumUdf() = 0;
+
+ // Populate the udfEntries, starting at the given index.
+ virtual void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) =
0;
+
+ // Register all function signatures to velox.
+ virtual void registerSignatures() = 0;
+};
+
+namespace myudf {
+
template <TypeKind Kind>
-class PlusConstantFunction : public exec::VectorFunction {
+class PlusFiveFunction : public exec::VectorFunction {
public:
- explicit PlusConstantFunction(int32_t addition) : addition_(addition) {}
+ explicit PlusFiveFunction() {}
void apply(
const SelectivityVector& rows,
@@ -42,12 +56,6 @@ class PlusConstantFunction : public exec::VectorFunction {
exec::EvalCtx& context,
VectorPtr& result) const override {
using nativeType = typename TypeTraits<Kind>::NativeType;
- VELOX_CHECK_EQ(args.size(), 1);
-
- auto& arg = args[0];
-
- // The argument may be flat or constant.
- VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
BaseVector::ensureWritable(rows, createScalarType<Kind>(), context.pool(),
result);
@@ -56,79 +64,218 @@ class PlusConstantFunction : public exec::VectorFunction {
flatResult->clearNulls(rows);
- if (arg->isConstantEncoding()) {
- auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
- rows.applyToSelected([&](auto row) { rawResult[row] = value + addition_;
});
- } else {
- auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+ rows.applyToSelected([&](auto row) { rawResult[row] = 5; });
- rows.applyToSelected([&](auto row) { rawResult[row] = rawInput[row] +
addition_; });
+ if (args.size() == 0) {
+ return;
}
- }
- private:
- const int32_t addition_;
-};
-
-template <typename T>
-struct MyDateSimpleFunction {
- VELOX_DEFINE_FUNCTION_TYPES(T);
-
- FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type<Date>& date,
const arg_type<int32_t> addition) {
- result = date + addition;
+ for (int i = 0; i < args.size(); i++) {
+ auto& arg = args[i];
+ VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
+ if (arg->isConstantEncoding()) {
+ auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
+ rows.applyToSelected([&](auto row) { rawResult[row] += value; });
+ } else {
+ auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+ rows.applyToSelected([&](auto row) { rawResult[row] += rawInput[row];
});
+ }
+ }
}
};
-std::shared_ptr<facebook::velox::exec::VectorFunction> makeMyUdf1(
+static std::shared_ptr<facebook::velox::exec::VectorFunction> makePlusConstant(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
+ if (inputArgs.size() == 0) {
+ return std::make_shared<PlusFiveFunction<TypeKind::INTEGER>>();
+ }
auto typeKind = inputArgs[0].type->kind();
switch (typeKind) {
case TypeKind::INTEGER:
- return std::make_shared<PlusConstantFunction<TypeKind::INTEGER>>(5);
+ return std::make_shared<PlusFiveFunction<TypeKind::INTEGER>>();
case TypeKind::BIGINT:
- return std::make_shared<PlusConstantFunction<TypeKind::BIGINT>>(5);
+ return std::make_shared<PlusFiveFunction<TypeKind::BIGINT>>();
default:
VELOX_UNREACHABLE();
}
}
-static std::vector<std::shared_ptr<exec::FunctionSignature>>
integerSignatures() {
- // integer -> integer, bigint ->bigint
- return {
-
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(),
-
exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
-}
+// name: myudf1
+// signatures:
+// bigint -> bigint
+// type: VectorFunction
+class MyUdf1Registerer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 1;
+ }
-static std::vector<std::shared_ptr<exec::FunctionSignature>>
bigintSignatures() {
- // bigint -> bigint
- return
{exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
-}
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::exec::registerVectorFunction(
+ name_, bigintSignatures(),
std::make_unique<PlusFiveFunction<facebook::velox::TypeKind::BIGINT>>());
+ }
+
+ private:
+ std::vector<std::shared_ptr<exec::FunctionSignature>> bigintSignatures() {
+ return
{exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+ }
+
+ const std::string name_ = "myudf1";
+ const char* bigintArg_[1] = {kBigInt};
+};
+
+// name: myudf2
+// signatures:
+// integer -> integer
+// bigint -> bigint
+// type: StatefulVectorFunction
+class MyUdf2Registerer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 2;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_};
+ udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::exec::registerStatefulVectorFunction(name_,
integerAndBigintSignatures(), makePlusConstant);
+ }
+
+ private:
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
integerAndBigintSignatures() {
+ return {
+
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(),
+
exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+ }
+
+ const std::string name_ = "myudf2";
+ const char* integerArg_[1] = {kInteger};
+ const char* bigintArg_[1] = {kBigInt};
+};
+
+// name: myudf3
+// signatures:
+// [integer,] ... -> integer
+// bigint, [bigint,] ... -> bigint
+// type: StatefulVectorFunction with variable arity
+class MyUdf3Registerer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 2;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_, true};
+ udfEntries[index++] = {name_.c_str(), kBigInt, 2, bigintArgs_, true};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::exec::registerStatefulVectorFunction(
+ name_, integerAndBigintSignaturesWithVariableArity(),
makePlusConstant);
+ }
+
+ private:
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
integerAndBigintSignaturesWithVariableArity() {
+ return {
+
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").variableArity().build(),
+ exec::FunctionSignatureBuilder()
+ .returnType("bigint")
+ .argumentType("bigint")
+ .argumentType("bigint")
+ .variableArity()
+ .build()};
+ }
-} // namespace
+ const std::string name_ = "myudf3";
+ const char* integerArg_[1] = {kInteger};
+ const char* bigintArgs_[2] = {kBigInt, kBigInt};
+};
+} // namespace myudf
-const int kNumMyUdf = 4;
+namespace mydate {
+template <typename T>
+struct MyDateSimpleFunction {
+ VELOX_DEFINE_FUNCTION_TYPES(T);
+
+ FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type<Date>& date,
const arg_type<int32_t> addition) {
+ result = date + addition;
+ }
+};
+
+// name: mydate
+// signatures:
+// date, integer -> date
+// type: SimpleFunction
+class MyDateRegisterer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 1;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kDate, 2, myDateArg_};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::registerFunction<mydate::MyDateSimpleFunction, Date,
Date, int32_t>({name_});
+ }
+
+ private:
+ const std::string name_ = "mydate";
+ const char* myDateArg_[2] = {kDate, kInteger};
+};
+} // namespace mydate
+
+std::vector<std::shared_ptr<UdfRegisterer>>& globalRegisters() {
+ static std::vector<std::shared_ptr<UdfRegisterer>> registerers;
+ return registerers;
+}
+
+void setupRegisterers() {
+ static bool inited = false;
+ if (inited) {
+ return;
+ }
+ auto& registerers = globalRegisters();
+ registerers.push_back(std::make_shared<myudf::MyUdf1Registerer>());
+ registerers.push_back(std::make_shared<myudf::MyUdf2Registerer>());
+ registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
+ registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
+ inited = true;
+}
DEFINE_GET_NUM_UDF {
- return kNumMyUdf;
+ setupRegisterers();
+
+ int numUdf = 0;
+ for (const auto& registerer : globalRegisters()) {
+ numUdf += registerer->getNumUdf();
+ }
+ return numUdf;
}
-const char* myUdf1Arg1[] = {kInteger};
-const char* myUdf1Arg2[] = {kBigInt};
-const char* myUdf2Arg1[] = {kBigInt};
-const char* myDateArg[] = {kDate, kInteger};
DEFINE_GET_UDF_ENTRIES {
+ setupRegisterers();
+
int index = 0;
- udfEntries[index++] = {"myudf1", kInteger, 1, myUdf1Arg1};
- udfEntries[index++] = {"myudf1", kBigInt, 1, myUdf1Arg2};
- udfEntries[index++] = {"myudf2", kBigInt, 1, myUdf2Arg1};
- udfEntries[index++] = {"mydate", kDate, 2, myDateArg};
+ for (const auto& registerer : globalRegisters()) {
+ registerer->populateUdfEntries(index, udfEntries);
+ }
}
DEFINE_REGISTER_UDF {
- facebook::velox::exec::registerStatefulVectorFunction("myudf1",
integerSignatures(), makeMyUdf1);
- facebook::velox::exec::registerVectorFunction(
- "myudf2", bigintSignatures(),
std::make_unique<PlusConstantFunction<facebook::velox::TypeKind::BIGINT>>(5));
- facebook::velox::registerFunction<MyDateSimpleFunction, Date, Date,
int32_t>({"mydate"});
+ setupRegisterers();
+
+ for (const auto& registerer : globalRegisters()) {
+ registerer->registerSignatures();
+ }
}
diff --git a/cpp/velox/udf/Udf.h
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
similarity index 60%
copy from cpp/velox/udf/Udf.h
copy to
shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c3e579c44..597b5936f 100644
--- a/cpp/velox/udf/Udf.h
+++
b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -14,26 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.sql.catalyst.types
-#pragma once
+import org.apache.spark.sql.types.DataType
-namespace gluten {
+object DataTypeUtils {
-struct UdfEntry {
- const char* name;
- const char* dataType;
-
- size_t numArgs;
- const char** argTypes;
-};
-
-#define GLUTEN_GET_NUM_UDF getNumUdf
-#define DEFINE_GET_NUM_UDF extern "C" int GLUTEN_GET_NUM_UDF()
-
-#define GLUTEN_GET_UDF_ENTRIES getUdfEntries
-#define DEFINE_GET_UDF_ENTRIES extern "C" void
GLUTEN_GET_UDF_ENTRIES(gluten::UdfEntry* udfEntries)
-
-#define GLUTEN_REGISTER_UDF registerUdf
-#define DEFINE_REGISTER_UDF extern "C" void GLUTEN_REGISTER_UDF()
-
-} // namespace gluten
+ /**
+ * Check if `this` and `other` are the same data type when ignoring
nullability
+ * (`StructField.nullable`, `ArrayType.containsNull`, and
`MapType.valueContainsNull`).
+ */
+ def sameType(left: DataType, right: DataType): Boolean = left.sameType(right)
+}
diff --git a/cpp/velox/udf/Udf.h
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
similarity index 60%
copy from cpp/velox/udf/Udf.h
copy to
shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c3e579c44..597b5936f 100644
--- a/cpp/velox/udf/Udf.h
+++
b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -14,26 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.sql.catalyst.types
-#pragma once
+import org.apache.spark.sql.types.DataType
-namespace gluten {
+object DataTypeUtils {
-struct UdfEntry {
- const char* name;
- const char* dataType;
-
- size_t numArgs;
- const char** argTypes;
-};
-
-#define GLUTEN_GET_NUM_UDF getNumUdf
-#define DEFINE_GET_NUM_UDF extern "C" int GLUTEN_GET_NUM_UDF()
-
-#define GLUTEN_GET_UDF_ENTRIES getUdfEntries
-#define DEFINE_GET_UDF_ENTRIES extern "C" void
GLUTEN_GET_UDF_ENTRIES(gluten::UdfEntry* udfEntries)
-
-#define GLUTEN_REGISTER_UDF registerUdf
-#define DEFINE_REGISTER_UDF extern "C" void GLUTEN_REGISTER_UDF()
-
-} // namespace gluten
+ /**
+ * Check if `this` and `other` are the same data type when ignoring
nullability
+ * (`StructField.nullable`, `ArrayType.containsNull`, and
`MapType.valueContainsNull`).
+ */
+ def sameType(left: DataType, right: DataType): Boolean = left.sameType(right)
+}
diff --git a/cpp/velox/udf/Udf.h
b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
similarity index 60%
copy from cpp/velox/udf/Udf.h
copy to
shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c3e579c44..597b5936f 100644
--- a/cpp/velox/udf/Udf.h
+++
b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -14,26 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.sql.catalyst.types
-#pragma once
+import org.apache.spark.sql.types.DataType
-namespace gluten {
+object DataTypeUtils {
-struct UdfEntry {
- const char* name;
- const char* dataType;
-
- size_t numArgs;
- const char** argTypes;
-};
-
-#define GLUTEN_GET_NUM_UDF getNumUdf
-#define DEFINE_GET_NUM_UDF extern "C" int GLUTEN_GET_NUM_UDF()
-
-#define GLUTEN_GET_UDF_ENTRIES getUdfEntries
-#define DEFINE_GET_UDF_ENTRIES extern "C" void
GLUTEN_GET_UDF_ENTRIES(gluten::UdfEntry* udfEntries)
-
-#define GLUTEN_REGISTER_UDF registerUdf
-#define DEFINE_REGISTER_UDF extern "C" void GLUTEN_REGISTER_UDF()
-
-} // namespace gluten
+ /**
+ * Check if `this` and `other` are the same data type when ignoring
nullability
+ * (`StructField.nullable`, `ArrayType.containsNull`, and
`MapType.valueContainsNull`).
+ */
+ def sameType(left: DataType, right: DataType): Boolean = left.sameType(right)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]