PHILO-HE commented on code in PR #5495:
URL: https://github.com/apache/incubator-gluten/pull/5495#discussion_r1578927190
##########
cpp/velox/udf/examples/MyUDF.cc:
##########
@@ -56,33 +64,36 @@ class PlusConstantFunction : public exec::VectorFunction {
flatResult->clearNulls(rows);
- if (arg->isConstantEncoding()) {
- auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
- rows.applyToSelected([&](auto row) { rawResult[row] = value + addition_;
});
- } else {
- auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+ rows.applyToSelected([&](auto row) { rawResult[row] = addition_; });
+
+ if (args.size() == 0) {
+ return;
+ }
- rows.applyToSelected([&](auto row) { rawResult[row] = rawInput[row] +
addition_; });
+ for (int i = 0; i < args.size(); i++) {
+ auto& arg = args[i];
+ VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
+ if (arg->isConstantEncoding()) {
+ auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
+ rows.applyToSelected([&](auto row) { rawResult[row] += value; });
+ } else {
+ auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
+ rows.applyToSelected([&](auto row) { rawResult[row] += rawInput[row];
});
+ }
}
}
private:
const int32_t addition_;
};
-template <typename T>
-struct MyDateSimpleFunction {
- VELOX_DEFINE_FUNCTION_TYPES(T);
-
- FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type<Date>& date,
const arg_type<int32_t> addition) {
- result = date + addition;
- }
-};
-
-std::shared_ptr<facebook::velox::exec::VectorFunction> makeMyUdf1(
+static std::shared_ptr<facebook::velox::exec::VectorFunction> makePlusConstant(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
+ if (inputArgs.size() == 0) {
+ return std::make_shared<PlusConstantFunction<TypeKind::INTEGER>>(5);
Review Comment:
Though it's only an example, it would be better to add a comment here, as
using the literal 5 directly may confuse readers.
##########
cpp/velox/udf/examples/MyUDF.cc:
##########
@@ -94,41 +105,180 @@ std::shared_ptr<facebook::velox::exec::VectorFunction>
makeMyUdf1(
}
}
-static std::vector<std::shared_ptr<exec::FunctionSignature>>
integerSignatures() {
- // integer -> integer, bigint ->bigint
- return {
-
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(),
-
exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
-}
+// name: myudf1
+// signatures:
+// bigint -> bigint
+// type: VectorFunction
+class MyUdf1Registerer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 1;
+ }
-static std::vector<std::shared_ptr<exec::FunctionSignature>>
bigintSignatures() {
- // bigint -> bigint
- return
{exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
-}
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::exec::registerVectorFunction(
+ name_, bigintSignatures(),
std::make_unique<PlusConstantFunction<facebook::velox::TypeKind::BIGINT>>(5));
+ }
+
+ private:
+ std::vector<std::shared_ptr<exec::FunctionSignature>> bigintSignatures() {
+ return
{exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+ }
+
+ const std::string name_ = "myudf1";
+ const char* bigintArg_[1] = {kBigInt};
+};
+
+// name: myudf2
+// signatures:
+// integer -> integer
+// bigint -> bigint
+// type: StatefulVectorFunction
+class MyUdf2Registerer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 2;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_};
+ udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::exec::registerStatefulVectorFunction(name_,
integerAndBigintSignatures(), makePlusConstant);
+ }
+
+ private:
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
integerAndBigintSignatures() {
+ return {
+
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(),
+
exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()};
+ }
+
+ const std::string name_ = "myudf2";
+ const char* integerArg_[1] = {kInteger};
+ const char* bigintArg_[1] = {kBigInt};
+};
+
+// name: myudf3
+// signatures:
+// [integer,] ... -> integer
+// bigint, [bigint,] ... -> bigint
+// type: StatefulVectorFunction with variable arity
+class MyUdf3Registerer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 2;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_, true};
+ udfEntries[index++] = {name_.c_str(), kBigInt, 2, bigintArgs_, true};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::exec::registerStatefulVectorFunction(
+ name_, integerAndBigintSignaturesWithVariableArity(),
makePlusConstant);
+ }
+
+ private:
+ std::vector<std::shared_ptr<exec::FunctionSignature>>
integerAndBigintSignaturesWithVariableArity() {
+ return {
+
exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").variableArity().build(),
+ exec::FunctionSignatureBuilder()
+ .returnType("bigint")
+ .argumentType("bigint")
+ .argumentType("bigint")
+ .variableArity()
+ .build()};
+ }
-} // namespace
+ const std::string name_ = "myudf3";
+ const char* integerArg_[1] = {kInteger};
+ const char* bigintArgs_[2] = {kBigInt, kBigInt};
+};
+} // namespace myudf
-const int kNumMyUdf = 4;
+namespace mydate {
+template <typename T>
+struct MyDateSimpleFunction {
+ VELOX_DEFINE_FUNCTION_TYPES(T);
+
+ FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type<Date>& date,
const arg_type<int32_t> addition) {
+ result = date + addition;
+ }
+};
+
+// name: mydate
+// signatures:
+// date, integer -> bigint
+// type: SimpleFunction
+class MyDateRegisterer final : public UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 1;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ udfEntries[index++] = {name_.c_str(), kDate, 2, myDateArg_};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::registerFunction<mydate::MyDateSimpleFunction, Date,
Date, int32_t>({name_});
+ }
+
+ private:
+ const std::string name_ = "mydate";
+ const char* myDateArg_[2] = {kDate, kInteger};
+};
+} // namespace mydate
+
+std::vector<std::shared_ptr<UdfRegisterer>>& globalRegisters() {
+ static std::vector<std::shared_ptr<UdfRegisterer>> registerers;
+ return registerers;
+}
+
+void setupRegisterers() {
+ static bool inited = false;
+ if (inited) {
+ return;
+ }
+ auto& registerers = globalRegisters();
+ registerers.push_back(std::make_shared<myudf::MyUdf1Registerer>());
+ registerers.push_back(std::make_shared<myudf::MyUdf2Registerer>());
+ registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
+ registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
+ inited = true;
+}
DEFINE_GET_NUM_UDF {
- return kNumMyUdf;
+ setupRegisterers();
+
+ int numUdf = 0;
+ for (const auto& registerer : globalRegisters()) {
+ numUdf += registerer->getNumUdf();
+ }
+ return numUdf;
}
-const char* myUdf1Arg1[] = {kInteger};
-const char* myUdf1Arg2[] = {kBigInt};
-const char* myUdf2Arg1[] = {kBigInt};
-const char* myDateArg[] = {kDate, kInteger};
DEFINE_GET_UDF_ENTRIES {
+ setupRegisterers();
+
int index = 0;
- udfEntries[index++] = {"myudf1", kInteger, 1, myUdf1Arg1};
- udfEntries[index++] = {"myudf1", kBigInt, 1, myUdf1Arg2};
- udfEntries[index++] = {"myudf2", kBigInt, 1, myUdf2Arg1};
- udfEntries[index++] = {"mydate", kDate, 2, myDateArg};
+ for (const auto& registerer : globalRegisters()) {
Review Comment:
Rename `registerer` -> `register`? Ditto for other places.
##########
backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala:
##########
@@ -319,30 +355,78 @@ object UDFResolver extends Logging {
}
private def getUdfExpression(name: String)(children: Seq[Expression]) = {
- val expressionType =
- UDFMap.getOrElse(
- (name, children.map(_.dataType)),
- throw new UnsupportedOperationException(
- s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(",
")} " +
- s"is not registered.")
- )
- UDFExpression(name, expressionType.dataType, expressionType.nullable,
children)
+ def errorMessage: String =
+ s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
+
+ val signatures =
+ UDFMap.getOrElse(name, throw new
UnsupportedOperationException(errorMessage));
+
+ signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+ case Some(sig) =>
+ UDFExpression(name, sig.expressionType.dataType,
sig.expressionType.nullable, children)
+ case None =>
+ throw new UnsupportedOperationException(errorMessage)
+ }
}
private def getUdafExpression(name: String)(children: Seq[Expression]) = {
- val (expressionType, aggBufferAttributes) =
+ def errorMessage: String =
+ s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
+
+ val signatures =
UDAFMap.getOrElse(
- (name, children.map(_.dataType)),
- throw new UnsupportedOperationException(
- s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(",
")} " +
- s"is not registered.")
+ name,
+ throw new UnsupportedOperationException(errorMessage)
)
- UserDefinedAggregateFunction(
- name,
- expressionType.dataType,
- expressionType.nullable,
- children,
- aggBufferAttributes)
+ signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+ case Some(sig) =>
+ UserDefinedAggregateFunction(
+ name,
+ sig.expressionType.dataType,
+ sig.expressionType.nullable,
+ children,
+ sig.intermediateAttrs)
+ case None =>
+ throw new UnsupportedOperationException(errorMessage)
+ }
+ }
+
+ private def tryBind(sig: UDFSignatureBase, requiredDataTypes:
Seq[DataType]): Boolean = {
Review Comment:
Could you add a comment to clarify this key method?
##########
cpp/velox/udf/examples/MyUDF.cc:
##########
@@ -21,15 +21,29 @@
#include <iostream>
#include "udf/Udf.h"
-namespace {
-
using namespace facebook::velox;
using namespace facebook::velox::exec;
static const char* kInteger = "int";
static const char* kBigInt = "bigint";
static const char* kDate = "date";
+class UdfRegisterer {
Review Comment:
Rename to `UdfRegister`?
##########
backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala:
##########
@@ -319,30 +355,78 @@ object UDFResolver extends Logging {
}
private def getUdfExpression(name: String)(children: Seq[Expression]) = {
- val expressionType =
- UDFMap.getOrElse(
- (name, children.map(_.dataType)),
- throw new UnsupportedOperationException(
- s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(",
")} " +
- s"is not registered.")
- )
- UDFExpression(name, expressionType.dataType, expressionType.nullable,
children)
+ def errorMessage: String =
+ s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
+
+ val signatures =
+ UDFMap.getOrElse(name, throw new
UnsupportedOperationException(errorMessage));
+
+ signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+ case Some(sig) =>
+ UDFExpression(name, sig.expressionType.dataType,
sig.expressionType.nullable, children)
+ case None =>
+ throw new UnsupportedOperationException(errorMessage)
+ }
}
private def getUdafExpression(name: String)(children: Seq[Expression]) = {
- val (expressionType, aggBufferAttributes) =
+ def errorMessage: String =
+ s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
+
+ val signatures =
UDAFMap.getOrElse(
- (name, children.map(_.dataType)),
- throw new UnsupportedOperationException(
- s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(",
")} " +
- s"is not registered.")
+ name,
+ throw new UnsupportedOperationException(errorMessage)
)
- UserDefinedAggregateFunction(
- name,
- expressionType.dataType,
- expressionType.nullable,
- children,
- aggBufferAttributes)
+ signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
+ case Some(sig) =>
+ UserDefinedAggregateFunction(
+ name,
+ sig.expressionType.dataType,
+ sig.expressionType.nullable,
+ children,
+ sig.intermediateAttrs)
+ case None =>
+ throw new UnsupportedOperationException(errorMessage)
+ }
+ }
+
+ private def tryBind(sig: UDFSignatureBase, requiredDataTypes:
Seq[DataType]): Boolean = {
+ if (!sig.variableArity) {
+ sig.children.size == requiredDataTypes.size &&
+ sig.children
+ .zip(requiredDataTypes)
+ .forall { case (candidate, required) =>
DataTypeUtils.sameType(candidate, required) }
+ } else {
+ // If variableArity is true, there must be at least one argument in .the
signature.
Review Comment:
Typo: "*** in .the ***" -- please drop the stray `.`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]