This is an automated email from the ASF dual-hosted git repository.
maruilei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new c7240b2e [AURON #1613] Introduce AuronExtFunctions (#1614)
c7240b2e is described below
commit c7240b2ee14ea8bda22ce81723a99b0e200d7af6
Author: zhangmang <[email protected]>
AuthorDate: Mon Nov 10 15:01:04 2025 +0800
[AURON #1613] Introduce AuronExtFunctions (#1614)
* [AURON #1613] Introduce AuronExtFunctions
* fix conflicts
---
native-engine/auron-serde/proto/auron.proto | 2 +-
native-engine/auron-serde/src/from_proto.rs | 6 +-
native-engine/datafusion-ext-functions/src/lib.rs | 73 +++++++++++----------
.../org/apache/spark/sql/auron/ShimsImpl.scala | 4 +-
.../apache/spark/sql/auron/NativeConverters.scala | 74 ++++++++++++----------
.../sql/execution/auron/plan/NativeAggBase.scala | 2 +-
6 files changed, 85 insertions(+), 76 deletions(-)
diff --git a/native-engine/auron-serde/proto/auron.proto
b/native-engine/auron-serde/proto/auron.proto
index 3fd05252..5a4076f2 100644
--- a/native-engine/auron-serde/proto/auron.proto
+++ b/native-engine/auron-serde/proto/auron.proto
@@ -274,7 +274,7 @@ enum ScalarFunction {
Nvl2=83;
Least=84;
Greatest=85;
- SparkExtFunctions=10000;
+ AuronExtFunctions=10000;
}
message PhysicalScalarFunctionNode {
diff --git a/native-engine/auron-serde/src/from_proto.rs
b/native-engine/auron-serde/src/from_proto.rs
index ef1cff44..45071360 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-serde/src/from_proto.rs
@@ -839,7 +839,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::Power => f::math::power(),
ScalarFunction::IsNaN => f::math::isnan(),
- ScalarFunction::SparkExtFunctions => {
+ ScalarFunction::AuronExtFunctions => {
unreachable!()
}
}
@@ -945,9 +945,9 @@ fn try_parse_physical_expr(
.map(|x| try_parse_physical_expr(x, input_schema))
.collect::<Result<Vec<_>, _>>()?;
- let scalar_udf = if scalar_function ==
protobuf::ScalarFunction::SparkExtFunctions {
+ let scalar_udf = if scalar_function ==
protobuf::ScalarFunction::AuronExtFunctions {
let fun_name = &e.name;
- let fun =
datafusion_ext_functions::create_spark_ext_function(fun_name)?;
+ let fun =
datafusion_ext_functions::create_auron_ext_function(fun_name)?;
Arc::new(create_udf(
&format!("spark_ext_function_{}", fun_name),
args.iter()
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs
b/native-engine/datafusion-ext-functions/src/lib.rs
index 96118e93..c48c2871 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -32,43 +32,48 @@ mod spark_sha2;
mod spark_strings;
mod spark_unscaled_value;
-pub fn create_spark_ext_function(name: &str) ->
Result<ScalarFunctionImplementation> {
+pub fn create_auron_ext_function(name: &str) ->
Result<ScalarFunctionImplementation> {
+ // auron ext functions: names used for spark should start with 'Spark_',
+ // names used for flink should start with 'Flink_',
+ // and likewise for other engines.
Ok(match name {
"Placeholder" => Arc::new(|_| panic!("placeholder() should never be
called")),
- "NullIf" => Arc::new(spark_null_if::spark_null_if),
- "NullIfZero" => Arc::new(spark_null_if::spark_null_if_zero),
- "UnscaledValue" =>
Arc::new(spark_unscaled_value::spark_unscaled_value),
- "MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
- "CheckOverflow" =>
Arc::new(spark_check_overflow::spark_check_overflow),
- "Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
- "XxHash64" => Arc::new(spark_hash::spark_xxhash64),
- "Sha224" => Arc::new(spark_sha2::spark_sha224),
- "Sha256" => Arc::new(spark_sha2::spark_sha256),
- "Sha384" => Arc::new(spark_sha2::spark_sha384),
- "Sha512" => Arc::new(spark_sha2::spark_sha512),
- "GetJsonObject" =>
Arc::new(spark_get_json_object::spark_get_json_object),
- "GetParsedJsonObject" =>
Arc::new(spark_get_json_object::spark_get_parsed_json_object),
- "ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
- "MakeArray" => Arc::new(spark_make_array::array),
- "StringSpace" => Arc::new(spark_strings::string_space),
- "StringRepeat" => Arc::new(spark_strings::string_repeat),
- "StringSplit" => Arc::new(spark_strings::string_split),
- "StringConcat" => Arc::new(spark_strings::string_concat),
- "StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
- "StringLower" => Arc::new(spark_strings::string_lower),
- "StringUpper" => Arc::new(spark_strings::string_upper),
- "Year" => Arc::new(spark_dates::spark_year),
- "Month" => Arc::new(spark_dates::spark_month),
- "Day" => Arc::new(spark_dates::spark_day),
- "Quarter" => Arc::new(spark_dates::spark_quarter),
- "Hour" => Arc::new(spark_dates::spark_hour),
- "Minute" => Arc::new(spark_dates::spark_minute),
- "Second" => Arc::new(spark_dates::spark_second),
- "BrickhouseArrayUnion" =>
Arc::new(brickhouse::array_union::array_union),
- "Round" => Arc::new(spark_round::spark_round),
- "NormalizeNanAndZero" => {
+ "Spark_NullIf" => Arc::new(spark_null_if::spark_null_if),
+ "Spark_NullIfZero" => Arc::new(spark_null_if::spark_null_if_zero),
+ "Spark_UnscaledValue" =>
Arc::new(spark_unscaled_value::spark_unscaled_value),
+ "Spark_MakeDecimal" =>
Arc::new(spark_make_decimal::spark_make_decimal),
+ "Spark_CheckOverflow" =>
Arc::new(spark_check_overflow::spark_check_overflow),
+ "Spark_Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
+ "Spark_XxHash64" => Arc::new(spark_hash::spark_xxhash64),
+ "Spark_Sha224" => Arc::new(spark_sha2::spark_sha224),
+ "Spark_Sha256" => Arc::new(spark_sha2::spark_sha256),
+ "Spark_Sha384" => Arc::new(spark_sha2::spark_sha384),
+ "Spark_Sha512" => Arc::new(spark_sha2::spark_sha512),
+ "Spark_GetJsonObject" =>
Arc::new(spark_get_json_object::spark_get_json_object),
+ "Spark_GetParsedJsonObject" => {
+ Arc::new(spark_get_json_object::spark_get_parsed_json_object)
+ }
+ "Spark_ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
+ "Spark_MakeArray" => Arc::new(spark_make_array::array),
+ "Spark_StringSpace" => Arc::new(spark_strings::string_space),
+ "Spark_StringRepeat" => Arc::new(spark_strings::string_repeat),
+ "Spark_StringSplit" => Arc::new(spark_strings::string_split),
+ "Spark_StringConcat" => Arc::new(spark_strings::string_concat),
+ "Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
+ "Spark_StringLower" => Arc::new(spark_strings::string_lower),
+ "Spark_StringUpper" => Arc::new(spark_strings::string_upper),
+ "Spark_Year" => Arc::new(spark_dates::spark_year),
+ "Spark_Month" => Arc::new(spark_dates::spark_month),
+ "Spark_Day" => Arc::new(spark_dates::spark_day),
+ "Spark_Quarter" => Arc::new(spark_dates::spark_quarter),
+ "Spark_Hour" => Arc::new(spark_dates::spark_hour),
+ "Spark_Minute" => Arc::new(spark_dates::spark_minute),
+ "Spark_Second" => Arc::new(spark_dates::spark_second),
+ "Spark_BrickhouseArrayUnion" =>
Arc::new(brickhouse::array_union::array_union),
+ "Spark_Round" => Arc::new(spark_round::spark_round),
+ "Spark_NormalizeNanAndZero" => {
Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
}
- _ => df_unimplemented_err!("spark ext function not implemented:
{name}")?,
+ _ => df_unimplemented_err!("auron ext function not implemented:
{name}")?,
})
}
diff --git
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
index 0ff2f5a5..0ab39213 100644
---
a/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
+++
b/spark-extension-shims-spark/src/main/scala/org/apache/spark/sql/auron/ShimsImpl.scala
@@ -528,8 +528,8 @@ class ShimsImpl extends Shims with Logging {
.setScalarFunction(
pb.PhysicalScalarFunctionNode
.newBuilder()
- .setFun(pb.ScalarFunction.SparkExtFunctions)
- .setName("StringSplit")
+ .setFun(pb.ScalarFunction.AuronExtFunctions)
+ .setName("Spark_StringSplit")
.addArgs(NativeConverters.convertExprWithFallback(str,
isPruningExpr, fallback))
.addArgs(NativeConverters
.convertExprWithFallback(Literal(nativePat), isPruningExpr,
fallback))
diff --git
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 5ecb41bc..44643801 100644
---
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -698,7 +698,7 @@ object NativeConverters extends Logging {
pb.PhysicalBinaryExprNode
.newBuilder()
.setL(convertExprWithFallback(Cast(lhs, resultType),
isPruningExpr, fallback))
- .setR(buildExtScalarFunction("NullIfZero", rhs :: Nil,
rhs.dataType))
+ .setR(buildExtScalarFunction("Spark_NullIfZero", rhs ::
Nil, rhs.dataType))
.setOp("Divide"))
}))
}
@@ -711,7 +711,7 @@ object NativeConverters extends Logging {
pb.PhysicalBinaryExprNode
.newBuilder()
.setL(convertExprWithFallback(lhsCasted, isPruningExpr,
fallback))
- .setR(buildExtScalarFunction("NullIfZero", rhsCasted :: Nil,
rhs.dataType))
+ .setR(buildExtScalarFunction("Spark_NullIfZero", rhsCasted ::
Nil, rhs.dataType))
.setOp("Divide"))
}
}
@@ -733,7 +733,8 @@ object NativeConverters extends Logging {
pb.PhysicalBinaryExprNode
.newBuilder()
.setL(convertExprWithFallback(lhsCasted, isPruningExpr,
fallback))
- .setR(buildExtScalarFunction("NullIfZero", rhsCasted :: Nil,
rhs.dataType))
+ .setR(
+ buildExtScalarFunction("Spark_NullIfZero", rhsCasted ::
Nil, rhs.dataType))
.setOp("Modulo"))
}
}
@@ -832,9 +833,9 @@ object NativeConverters extends Logging {
case e: Round =>
e.scale match {
case Literal(n: Int, _) =>
- buildExtScalarFunction("Round", Seq(e.child, Literal(n.toLong)),
e.dataType)
+ buildExtScalarFunction("Spark_Round", Seq(e.child,
Literal(n.toLong)), e.dataType)
case _ =>
- buildExtScalarFunction("Round", Seq(e.child, Literal(0L)),
e.dataType)
+ buildExtScalarFunction("Spark_Round", Seq(e.child, Literal(0L)),
e.dataType)
}
case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum,
e.children, e.dataType)
@@ -849,10 +850,10 @@ object NativeConverters extends Logging {
case e: Lower
if
sparkAuronConfig.getBoolean(SparkAuronConfiguration.CASE_CONVERT_FUNCTIONS_ENABLE)
=>
- buildExtScalarFunction("StringLower", e.children, e.dataType)
+ buildExtScalarFunction("Spark_StringLower", e.children, e.dataType)
case e: Upper
if
sparkAuronConfig.getBoolean(SparkAuronConfiguration.CASE_CONVERT_FUNCTIONS_ENABLE)
=>
- buildExtScalarFunction("StringUpper", e.children, e.dataType)
+ buildExtScalarFunction("Spark_StringUpper", e.children, e.dataType)
case e: StringTrim =>
buildScalarFunction(pb.ScalarFunction.Trim, e.srcStr +:
e.trimStr.toSeq, e.dataType)
@@ -861,7 +862,7 @@ object NativeConverters extends Logging {
case e: StringTrimRight =>
buildScalarFunction(pb.ScalarFunction.Rtrim, e.srcStr +:
e.trimStr.toSeq, e.dataType)
case e @ NullIf(left, right, _) =>
- buildExtScalarFunction("NullIf", left :: right :: Nil, e.dataType)
+ buildExtScalarFunction("Spark_NullIf", left :: right :: Nil,
e.dataType)
case Md5(_1) =>
buildScalarFunction(pb.ScalarFunction.MD5,
Seq(unpackBinaryTypeCast(_1)), StringType)
case Reverse(_1) =>
@@ -869,19 +870,19 @@ object NativeConverters extends Logging {
case InitCap(_1) =>
buildScalarFunction(pb.ScalarFunction.InitCap,
Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(224, _)) =>
- buildExtScalarFunction("Sha224", Seq(unpackBinaryTypeCast(_1)),
StringType)
+ buildExtScalarFunction("Spark_Sha224", Seq(unpackBinaryTypeCast(_1)),
StringType)
case Sha2(_1, Literal(0, _)) =>
- buildExtScalarFunction("Sha256", Seq(unpackBinaryTypeCast(_1)),
StringType)
+ buildExtScalarFunction("Spark_Sha256", Seq(unpackBinaryTypeCast(_1)),
StringType)
case Sha2(_1, Literal(256, _)) =>
- buildExtScalarFunction("Sha256", Seq(unpackBinaryTypeCast(_1)),
StringType)
+ buildExtScalarFunction("Spark_Sha256", Seq(unpackBinaryTypeCast(_1)),
StringType)
case Sha2(_1, Literal(384, _)) =>
- buildExtScalarFunction("Sha384", Seq(unpackBinaryTypeCast(_1)),
StringType)
+ buildExtScalarFunction("Spark_Sha384", Seq(unpackBinaryTypeCast(_1)),
StringType)
case Sha2(_1, Literal(512, _)) =>
- buildExtScalarFunction("Sha512", Seq(unpackBinaryTypeCast(_1)),
StringType)
+ buildExtScalarFunction("Spark_Sha512", Seq(unpackBinaryTypeCast(_1)),
StringType)
case Murmur3Hash(children, 42) =>
- buildExtScalarFunction("Murmur3Hash", children, IntegerType)
+ buildExtScalarFunction("Spark_Murmur3Hash", children, IntegerType)
case XxHash64(children, 42L) =>
- buildExtScalarFunction("XxHash64", children, LongType)
+ buildExtScalarFunction("Spark_XxHash64", children, LongType)
case e: Greatest =>
buildScalarFunction(pb.ScalarFunction.Greatest, e.children, e.dataType)
case e: Pow =>
@@ -889,20 +890,20 @@ object NativeConverters extends Logging {
case e: Nvl =>
buildScalarFunction(pb.ScalarFunction.Nvl, e.children, e.dataType)
- case Year(child) => buildExtScalarFunction("Year", child :: Nil,
IntegerType)
- case Month(child) => buildExtScalarFunction("Month", child :: Nil,
IntegerType)
- case DayOfMonth(child) => buildExtScalarFunction("Day", child :: Nil,
IntegerType)
- case Quarter(child) => buildExtScalarFunction("Quarter", child :: Nil,
IntegerType)
+ case Year(child) => buildExtScalarFunction("Spark_Year", child :: Nil,
IntegerType)
+ case Month(child) => buildExtScalarFunction("Spark_Month", child :: Nil,
IntegerType)
+ case DayOfMonth(child) => buildExtScalarFunction("Spark_Day", child ::
Nil, IntegerType)
+ case Quarter(child) => buildExtScalarFunction("Spark_Quarter", child ::
Nil, IntegerType)
case e: Levenshtein =>
buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children,
e.dataType)
case e: Hour if datetimeExtractEnabled =>
- buildTimePartExt("Hour", e.children.head, isPruningExpr, fallback)
+ buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr,
fallback)
case e: Minute if datetimeExtractEnabled =>
- buildTimePartExt("Minute", e.children.head, isPruningExpr, fallback)
+ buildTimePartExt("Spark_Minute", e.children.head, isPruningExpr,
fallback)
case e: Second if datetimeExtractEnabled =>
- buildTimePartExt("Second", e.children.head, isPruningExpr, fallback)
+ buildTimePartExt("Spark_Second", e.children.head, isPruningExpr,
fallback)
// startswith is converted to scalar function in pruning-expr mode
case StartsWith(expr, Literal(prefix, StringType)) if isPruningExpr =>
@@ -949,20 +950,20 @@ object NativeConverters extends Logging {
StringType)
case StringSpace(n) =>
- buildExtScalarFunction("StringSpace", n :: Nil, StringType)
+ buildExtScalarFunction("Spark_StringSpace", n :: Nil, StringType)
case StringRepeat(str, n @ Literal(_, IntegerType)) =>
- buildExtScalarFunction("StringRepeat", str :: n :: Nil, StringType)
+ buildExtScalarFunction("Spark_StringRepeat", str :: n :: Nil,
StringType)
case e: Concat if e.children.forall(_.dataType == StringType) =>
- buildExtScalarFunction("StringConcat", e.children, e.dataType)
+ buildExtScalarFunction("Spark_StringConcat", e.children, e.dataType)
case e: ConcatWs
if e.children.nonEmpty
&& e.children.head.isInstanceOf[Literal]
&& e.children.forall(c =>
c.dataType == StringType || c.dataType == ArrayType(StringType))
=>
- buildExtScalarFunction("StringConcatWs", e.children, e.dataType)
+ buildExtScalarFunction("Spark_StringConcatWs", e.children, e.dataType)
case e: Coalesce =>
val children = e.children.map(Cast(_, e.dataType))
@@ -1011,7 +1012,7 @@ object NativeConverters extends Logging {
// expressions for DecimalPrecision rule
case UnscaledValue(_1) if decimalArithOpEnabled =>
val args = _1 :: Nil
- buildExtScalarFunction("UnscaledValue", args, LongType)
+ buildExtScalarFunction("Spark_UnscaledValue", args, LongType)
case e: MakeDecimal if decimalArithOpEnabled =>
val precision = e.precision
@@ -1019,7 +1020,7 @@ object NativeConverters extends Logging {
val args =
e.child :: Literal
.apply(precision, IntegerType) :: Literal.apply(scale,
IntegerType) :: Nil
- buildExtScalarFunction("MakeDecimal", args, DecimalType(precision,
scale))
+ buildExtScalarFunction("Spark_MakeDecimal", args,
DecimalType(precision, scale))
case e: CheckOverflow if decimalArithOpEnabled =>
val precision = e.dataType.precision
@@ -1027,13 +1028,13 @@ object NativeConverters extends Logging {
val args =
e.child :: Literal
.apply(precision, IntegerType) :: Literal.apply(scale,
IntegerType) :: Nil
- buildExtScalarFunction("CheckOverflow", args, DecimalType(precision,
scale))
+ buildExtScalarFunction("Spark_CheckOverflow", args,
DecimalType(precision, scale))
case e: NormalizeNaNAndZero
if e.dataType.isInstanceOf[FloatType] ||
e.dataType.isInstanceOf[DoubleType] =>
- buildExtScalarFunction("NormalizeNanAndZero", e.children, e.dataType)
+ buildExtScalarFunction("Spark_NormalizeNanAndZero", e.children,
e.dataType)
- case e: CreateArray => buildExtScalarFunction("MakeArray", e.children,
e.dataType)
+ case e: CreateArray => buildExtScalarFunction("Spark_MakeArray",
e.children, e.dataType)
case e: CreateNamedStruct =>
buildExprNode {
@@ -1100,16 +1101,19 @@ object NativeConverters extends Logging {
// The benefit of this approach is that if there are multiple calls,
// the JSON object can be reused, which can significantly improve
performance.
val parsed = Shims.get.createNativeExprWrapper(
- buildExtScalarFunction("ParseJson", e.children(0) :: Nil,
BinaryType),
+ buildExtScalarFunction("Spark_ParseJson", e.children(0) :: Nil,
BinaryType),
BinaryType,
nullable = false)
- buildExtScalarFunction("GetParsedJsonObject", parsed :: e.children(1)
:: Nil, StringType)
+ buildExtScalarFunction(
+ "Spark_GetParsedJsonObject",
+ parsed :: e.children(1) :: Nil,
+ StringType)
// hive UDF brickhouse.array_union
case e
if
getFunctionClassName(e).contains("brickhouse.udf.collect.ArrayUnionUDF")
&& udfBrickHouseEnabled =>
- buildExtScalarFunction("BrickhouseArrayUnion", e.children, e.dataType)
+ buildExtScalarFunction("Spark_BrickhouseArrayUnion", e.children,
e.dataType)
case e =>
Shims.get.convertMoreExprWithFallback(e, isPruningExpr, fallback)
match {
@@ -1304,7 +1308,7 @@ object NativeConverters extends Logging {
pb.PhysicalScalarFunctionNode
.newBuilder()
.setName(name)
- .setFun(pb.ScalarFunction.SparkExtFunctions)
+ .setFun(pb.ScalarFunction.AuronExtFunctions)
.addAllArgs(
args.map(expr => convertExprWithFallback(expr, isPruningExpr,
fallback)).asJava)
.setReturnType(convertDataType(dataType)))
diff --git
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
index 7f75e4d6..54e2cddc 100644
---
a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
+++
b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala
@@ -275,7 +275,7 @@ object NativeAggBase extends Logging {
.setScalarFunction(
pb.PhysicalScalarFunctionNode
.newBuilder()
- .setFun(pb.ScalarFunction.SparkExtFunctions)
+ .setFun(pb.ScalarFunction.AuronExtFunctions)
.setName("Placeholder")
.setReturnType(NativeConverters.convertDataType(e.dataType)))
.build()