This is an automated email from the ASF dual-hosted git repository.
richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new 630ce1ce [AURON #1588] Implement native functions isnan(87), nvl2(83),
nvl(82), greatest(85), least(84), find_in_set(81) #1588 #1585 (#1585)
630ce1ce is described below
commit 630ce1ce9b19f66329c3a8e6963e412fc1180e74
Author: guixiaowen <[email protected]>
AuthorDate: Wed Nov 5 10:51:51 2025 +0800
[AURON #1588] Implement native functions isnan(87), nvl2(83), nvl(82),
greatest(85), least(84), find_in_set(81) #1588 #1585 (#1585)
Co-authored-by: guihuawen <[email protected]>
---
native-engine/auron-serde/proto/auron.proto | 6 +
native-engine/auron-serde/src/from_proto.rs | 6 +
.../spark/sql/auron/AuronFunctionSuite.scala | 217 +++++++++++++++++++++
.../apache/spark/sql/auron/NativeConverters.scala | 13 +-
4 files changed, 241 insertions(+), 1 deletion(-)
diff --git a/native-engine/auron-serde/proto/auron.proto
b/native-engine/auron-serde/proto/auron.proto
index 2f567ec8..3fd05252 100644
--- a/native-engine/auron-serde/proto/auron.proto
+++ b/native-engine/auron-serde/proto/auron.proto
@@ -267,7 +267,13 @@ enum ScalarFunction {
Factorial=65;
Hex=66;
Power=67;
+ IsNaN=69;
Levenshtein=80;
+ FindInSet=81;
+ Nvl=82;
+ Nvl2=83;
+ Least=84;
+ Greatest=85;
SparkExtFunctions=10000;
}
diff --git a/native-engine/auron-serde/src/from_proto.rs
b/native-engine/auron-serde/src/from_proto.rs
index 89fb1579..ef1cff44 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-serde/src/from_proto.rs
@@ -784,6 +784,8 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::Rtrim => f::string::rtrim(),
ScalarFunction::ToTimestamp => f::datetime::to_timestamp(),
ScalarFunction::NullIf => f::core::nullif(),
+ ScalarFunction::Nvl2 => f::core::nvl2(),
+ ScalarFunction::Nvl => f::core::nvl(),
ScalarFunction::DatePart => f::datetime::date_part(),
ScalarFunction::DateTrunc => f::datetime::date_trunc(),
ScalarFunction::Md5 => f::crypto::md5(),
@@ -815,6 +817,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::StartsWith => f::string::starts_with(),
ScalarFunction::Levenshtein => f::string::levenshtein(),
+ ScalarFunction::FindInSet => f::unicode::find_in_set(),
ScalarFunction::Strpos => f::unicode::strpos(),
ScalarFunction::Substr => f::unicode::substr(),
// ScalarFunction::ToHex => f::string::to_hex(),
@@ -823,7 +826,9 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::Now => f::datetime::now(),
ScalarFunction::Translate => f::unicode::translate(),
ScalarFunction::RegexpMatch => f::regex::regexp_match(),
+ ScalarFunction::Greatest => f::core::greatest(),
ScalarFunction::Coalesce => f::core::coalesce(),
+ ScalarFunction::Least => f::core::least(),
// -- datafusion-spark functions
// math functions
@@ -832,6 +837,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::Hex => spark_fun::math::hex(),
ScalarFunction::Power => f::math::power(),
+ ScalarFunction::IsNaN => f::math::isnan(),
ScalarFunction::SparkExtFunctions => {
unreachable!()
diff --git
a/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
b/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
index 8ef13584..99348978 100644
---
a/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
+++
b/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.auron
+import java.text.SimpleDateFormat
+
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -320,6 +322,221 @@ class AuronFunctionSuite
assert(row.isNullAt(0) && row.isNullAt(1) && row.isNullAt(2))
}
+ test("test function least") {
+ withTable("t1") {
+ sql(
+ "create table test_least using parquet as select 1 as c1, 2 as c2, 'a'
as c3, 'b' as c4, 'c' as c5")
+
+ val maxValue = Long.MaxValue
+ val minValue = Long.MinValue
+
+ val dateStringMin = "2015-01-01 08:00:00"
+ val dateStringMax = "2015-01-01 11:00:00"
+ var format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+ val dateTimeStampMin = format.parse(dateStringMin).getTime
+ val dateTimeStampMax = format.parse(dateStringMax).getTime
+ format = new SimpleDateFormat("yyyy-MM-dd")
+ val dateString = "2015-01-01"
+ val date = format.parse(dateString)
+
+ val functions =
+ s"""
+ |select
+ | least(c4, c3, c5),
+ | least(c1, c2, 1),
+ | least(c1, c2, -1),
+ | least(c4, c5, c3, c3, 'a'),
+ | least(null, null),
+ | least(c4, c3, c5, null),
+ | least(-1.0, 2.5),
+ | least(-1.0, 2),
+ | least(-1.0f, 2.5f),
+ | least(cast(1 as byte), cast(2 as byte)),
+ | least('abc', 'aaaa'),
+ | least(true, false),
+ | least(cast("2015-01-01" as date), cast("2015-07-01" as date)),
+ | least(${dateTimeStampMin}, ${dateTimeStampMax}),
+ | least(${minValue}, ${maxValue})
+ |from
+ | test_least
+ """.stripMargin
+
+ val df = sql(functions)
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(
+ "a",
+ 1,
+ -1,
+ "a",
+ null,
+ "a",
+ -1.0,
+ -1.0,
+ -1.0f,
+ 1,
+ "aaaa",
+ false,
+ date,
+ dateTimeStampMin,
+ minValue)))
+ }
+ }
+
+ test("test function greatest") {
+ withTable("t1") {
+ sql(
+ "create table t1 using parquet as select 1 as c1, 2 as c2, 'a' as c3,
'b' as c4, 'c' as c5")
+
+ val longMax = Long.MaxValue
+ val longMin = Long.MinValue
+ val dateStringMin = "2015-01-01 08:00:00"
+ val dateStringMax = "2015-01-01 11:00:00"
+ var format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+ val dateTimeStampMin = format.parse(dateStringMin).getTime
+ val dateTimeStampMax = format.parse(dateStringMax).getTime
+ format = new SimpleDateFormat("yyyy-MM-dd")
+ val dateString = "2015-07-01"
+ val date = format.parse(dateString)
+
+ val functions =
+ s"""
+ |select
+ | greatest(c3, c4, c5),
+ | greatest(c2, c1),
+ | greatest(c1, c2, 2),
+ | greatest(c4, c5, c3, 'ccc'),
+ | greatest(null, null),
+ | greatest(c3, c4, c5, null),
+ | greatest(-1.0, 2.5),
+ | greatest(-1, 2),
+ | greatest(-1.0f, 2.5f),
+ | greatest(${longMax}, ${longMin}),
+ | greatest(cast(1 as byte), cast(2 as byte)),
+ | greatest(cast(1 as short), cast(2 as short)),
+ | greatest("abc", "aaaa"),
+ | greatest(true, false),
+ | greatest(
+ | cast("2015-01-01" as date),
+ | cast("2015-07-01" as date)
+ | ),
+ | greatest(
+ | ${dateTimeStampMin},
+ | ${dateTimeStampMax}
+ | )
+ |from
+ | t1
+ """.stripMargin
+
+ val df = sql(functions)
+ checkAnswer(
+ df,
+ Seq(
+ Row(
+ "c",
+ 2,
+ 2,
+ "ccc",
+ null,
+ "c",
+ 2.5,
+ 2,
+ 2.5f,
+ longMax,
+ 2,
+ 2,
+ "abc",
+ true,
+ date,
+ dateTimeStampMax)))
+
+ }
+ }
+
+ test("test function FindInSet") {
+ withTable("t1") {
+ sql(
+ "create table t1_find_in_set using parquet as select 'ab' as a, 'b' as
b, '' as c, 'def' as d")
+
+ val functions =
+ """
+ |select
+ | find_in_set(a, 'ab'),
+ | find_in_set(b, 'a,b'),
+ | find_in_set(a, 'abc,b,ab,c,def'),
+ | find_in_set(a, 'ab,abc,b,ab,c,def'),
+ | find_in_set(a, ',,,ab,abc,b,ab,c,def'),
+ | find_in_set(c, ',ab,abc,b,ab,c,def'),
+ | find_in_set(a, '数据砖头,abc,b,ab,c,def'),
+ | find_in_set(d, '数据砖头,abc,b,ab,c,def'),
+ | find_in_set(d, null)
+ |from t1_find_in_set
+ """.stripMargin
+
+ val df = sql(functions)
+ df.show()
+ checkAnswer(df, Seq(Row(1, 2, 3, 1, 4, 1, 4, 6, null)))
+ }
+ }
+
+ test("test function IsNaN") {
+ withTable("t1") {
+ sql(
+ "create table test_is_nan using parquet as select cast('NaN' as
double) as c1, cast('NaN' as float) as c2, log(-3) as c3, cast(null as double)
as c4, 5.5f as c5")
+ val functions =
+ """
+ |select
+ | isnan(c1),
+ | isnan(c2),
+ | isnan(c3),
+ | isnan(c4),
+ | isnan(c5)
+ |from
+ | test_is_nan
+ """.stripMargin
+
+ val df = sql(functions)
+ df.show()
+ checkAnswer(df, Seq(Row(true, true, false, false, false)))
+ }
+ }
+
+ test("test function nvl2") {
+ withTable("t1") {
+ sql(
+ "create table t1 using parquet as select 'base'" +
+ " as base, 3 as exponent")
+ val functions =
+ """
+ |select
+ | nvl2(null, base, exponent), nvl2(4, base, exponent)
+ |from t1
+ """.stripMargin
+
+ val df = sql(functions)
+ checkAnswer(df, Seq(Row("base", 3)))
+ }
+ }
+
+ test("test function nvl") {
+ withTable("t1") {
+ sql(
+ "create table t1 using parquet as select 'base'" +
+ " as base, 3 as exponent")
+ val functions =
+ """
+ |select
+ | nvl(null, base), base, nvl(4, exponent)
+ |from t1
+ """.stripMargin
+
+ val df = sql(functions)
+ checkAnswer(df, Seq(Row("base", "base", 4)))
+ }
+ }
+
test("test function Levenshtein") {
withTable("t1") {
sql(
diff --git
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index e8447434..5ecb41bc 100644
---
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -794,6 +794,8 @@ object NativeConverters extends Logging {
buildScalarFunction(pb.ScalarFunction.Log2,
e.children.map(nullIfNegative), e.dataType)
case e: Log10 =>
buildScalarFunction(pb.ScalarFunction.Log10,
e.children.map(nullIfNegative), e.dataType)
+ case e: Nvl2 =>
+ buildScalarFunction(pb.ScalarFunction.Nvl2, e.children, e.dataType)
case e: Floor if !e.dataType.isInstanceOf[DecimalType] =>
if (e.child.dataType.isInstanceOf[LongType]) {
convertExprWithFallback(e.child, isPruningExpr, fallback)
@@ -821,10 +823,12 @@ object NativeConverters extends Logging {
}
}
case e: Expm1 => buildScalarFunction(pb.ScalarFunction.Expm1,
e.children, e.dataType)
+ case e: Least => buildScalarFunction(pb.ScalarFunction.Least,
e.children, e.dataType)
case e: Factorial =>
buildScalarFunction(pb.ScalarFunction.Factorial, e.children,
e.dataType)
case e: Hex => buildScalarFunction(pb.ScalarFunction.Hex, e.children,
e.dataType)
-
+ case e: IsNaN =>
+ buildScalarFunction(pb.ScalarFunction.IsNaN, e.children, e.dataType)
case e: Round =>
e.scale match {
case Literal(n: Int, _) =>
@@ -834,6 +838,8 @@ object NativeConverters extends Logging {
}
case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum,
e.children, e.dataType)
+ case e: FindInSet =>
+ buildScalarFunction(pb.ScalarFunction.FindInSet, e.children,
e.dataType)
case e: Abs if e.dataType.isInstanceOf[FloatType] ||
e.dataType.isInstanceOf[DoubleType] =>
buildScalarFunction(pb.ScalarFunction.Abs, e.children, e.dataType)
case e: OctetLength =>
@@ -876,8 +882,13 @@ object NativeConverters extends Logging {
buildExtScalarFunction("Murmur3Hash", children, IntegerType)
case XxHash64(children, 42L) =>
buildExtScalarFunction("XxHash64", children, LongType)
+ case e: Greatest =>
+ buildScalarFunction(pb.ScalarFunction.Greatest, e.children, e.dataType)
case e: Pow =>
buildScalarFunction(pb.ScalarFunction.Power, e.children, e.dataType)
+ case e: Nvl =>
+ buildScalarFunction(pb.ScalarFunction.Nvl, e.children, e.dataType)
+
case Year(child) => buildExtScalarFunction("Year", child :: Nil,
IntegerType)
case Month(child) => buildExtScalarFunction("Month", child :: Nil,
IntegerType)
case DayOfMonth(child) => buildExtScalarFunction("Day", child :: Nil,
IntegerType)