This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new 630ce1ce [AURON #1588]Implement native function of isnan(87),nvl2(83) 
,nvl(82), greatest(85), least(84), find_in_set(81) #1588 #1585 (#1585)
630ce1ce is described below

commit 630ce1ce9b19f66329c3a8e6963e412fc1180e74
Author: guixiaowen <[email protected]>
AuthorDate: Wed Nov 5 10:51:51 2025 +0800

    [AURON #1588]Implement native function of isnan(87),nvl2(83) ,nvl(82), 
greatest(85), least(84), find_in_set(81) #1588 #1585 (#1585)
    
    Co-authored-by: guihuawen <[email protected]>
---
 native-engine/auron-serde/proto/auron.proto        |   6 +
 native-engine/auron-serde/src/from_proto.rs        |   6 +
 .../spark/sql/auron/AuronFunctionSuite.scala       | 217 +++++++++++++++++++++
 .../apache/spark/sql/auron/NativeConverters.scala  |  13 +-
 4 files changed, 241 insertions(+), 1 deletion(-)

diff --git a/native-engine/auron-serde/proto/auron.proto 
b/native-engine/auron-serde/proto/auron.proto
index 2f567ec8..3fd05252 100644
--- a/native-engine/auron-serde/proto/auron.proto
+++ b/native-engine/auron-serde/proto/auron.proto
@@ -267,7 +267,13 @@ enum ScalarFunction {
   Factorial=65;
   Hex=66;
   Power=67;
+  IsNaN=69;
   Levenshtein=80;
+  FindInSet=81;
+  Nvl=82;
+  Nvl2=83;
+  Least=84;
+  Greatest=85;
   SparkExtFunctions=10000;
 }
 
diff --git a/native-engine/auron-serde/src/from_proto.rs 
b/native-engine/auron-serde/src/from_proto.rs
index 89fb1579..ef1cff44 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-serde/src/from_proto.rs
@@ -784,6 +784,8 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
             ScalarFunction::Rtrim => f::string::rtrim(),
             ScalarFunction::ToTimestamp => f::datetime::to_timestamp(),
             ScalarFunction::NullIf => f::core::nullif(),
+            ScalarFunction::Nvl2 => f::core::nvl2(),
+            ScalarFunction::Nvl => f::core::nvl(),
             ScalarFunction::DatePart => f::datetime::date_part(),
             ScalarFunction::DateTrunc => f::datetime::date_trunc(),
             ScalarFunction::Md5 => f::crypto::md5(),
@@ -815,6 +817,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
             ScalarFunction::StartsWith => f::string::starts_with(),
             ScalarFunction::Levenshtein => f::string::levenshtein(),
 
+            ScalarFunction::FindInSet => f::unicode::find_in_set(),
             ScalarFunction::Strpos => f::unicode::strpos(),
             ScalarFunction::Substr => f::unicode::substr(),
             // ScalarFunction::ToHex => f::string::to_hex(),
@@ -823,7 +826,9 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
             ScalarFunction::Now => f::datetime::now(),
             ScalarFunction::Translate => f::unicode::translate(),
             ScalarFunction::RegexpMatch => f::regex::regexp_match(),
+            ScalarFunction::Greatest => f::core::greatest(),
             ScalarFunction::Coalesce => f::core::coalesce(),
+            ScalarFunction::Least => f::core::least(),
 
             // -- datafusion-spark functions
             // math functions
@@ -832,6 +837,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
             ScalarFunction::Hex => spark_fun::math::hex(),
 
             ScalarFunction::Power => f::math::power(),
+            ScalarFunction::IsNaN => f::math::isnan(),
 
             ScalarFunction::SparkExtFunctions => {
                 unreachable!()
diff --git 
a/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
index 8ef13584..99348978 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.auron
 
+import java.text.SimpleDateFormat
+
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 
@@ -320,6 +322,221 @@ class AuronFunctionSuite
     assert(row.isNullAt(0) && row.isNullAt(1) && row.isNullAt(2))
   }
 
+  test("test function least") {
+    withTable("t1") {
+      sql(
+        "create table test_least using parquet as select 1 as c1, 2 as c2, 'a' 
as c3, 'b' as c4, 'c' as c5")
+
+      val maxValue = Long.MaxValue
+      val minValue = Long.MinValue
+
+      val dateStringMin = "2015-01-01 08:00:00"
+      val dateStringMax = "2015-01-01 11:00:00"
+      var format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+      val dateTimeStampMin = format.parse(dateStringMin).getTime
+      val dateTimeStampMax = format.parse(dateStringMax).getTime
+      format = new SimpleDateFormat("yyyy-MM-dd")
+      val dateString = "2015-01-01"
+      val date = format.parse(dateString)
+
+      val functions =
+        s"""
+          |select
+          |    least(c4, c3, c5),
+          |    least(c1, c2, 1),
+          |    least(c1, c2, -1),
+          |    least(c4, c5, c3, c3, 'a'),
+          |    least(null, null),
+          |    least(c4, c3, c5, null),
+          |    least(-1.0, 2.5),
+          |    least(-1.0, 2),
+          |    least(-1.0f, 2.5f),
+          |    least(cast(1 as byte), cast(2 as byte)),
+          |    least('abc', 'aaaa'),
+          |    least(true, false),
+          |    least(cast("2015-01-01" as date), cast("2015-07-01" as date)),
+          |    least(${dateTimeStampMin}, ${dateTimeStampMax}),
+          |    least(${minValue}, ${maxValue})
+          |from
+          |    test_least
+        """.stripMargin
+
+      val df = sql(functions)
+
+      checkAnswer(
+        df,
+        Seq(
+          Row(
+            "a",
+            1,
+            -1,
+            "a",
+            null,
+            "a",
+            -1.0,
+            -1.0,
+            -1.0f,
+            1,
+            "aaaa",
+            false,
+            date,
+            dateTimeStampMin,
+            minValue)))
+    }
+  }
+
+  test("test function greatest") {
+    withTable("t1") {
+      sql(
+        "create table t1 using parquet as select 1 as c1, 2 as c2, 'a' as c3, 
'b' as c4, 'c' as c5")
+
+      val longMax = Long.MaxValue
+      val longMin = Long.MinValue
+      val dateStringMin = "2015-01-01 08:00:00"
+      val dateStringMax = "2015-01-01 11:00:00"
+      var format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
+      val dateTimeStampMin = format.parse(dateStringMin).getTime
+      val dateTimeStampMax = format.parse(dateStringMax).getTime
+      format = new SimpleDateFormat("yyyy-MM-dd")
+      val dateString = "2015-07-01"
+      val date = format.parse(dateString)
+
+      val functions =
+        s"""
+          |select
+          |    greatest(c3, c4, c5),
+          |    greatest(c2, c1),
+          |    greatest(c1, c2, 2),
+          |    greatest(c4, c5, c3, 'ccc'),
+          |    greatest(null, null),
+          |    greatest(c3, c4, c5, null),
+          |    greatest(-1.0, 2.5),
+          |    greatest(-1, 2),
+          |    greatest(-1.0f, 2.5f),
+          |    greatest(${longMax}, ${longMin}),
+          |    greatest(cast(1 as byte), cast(2 as byte)),
+          |    greatest(cast(1 as short), cast(2 as short)),
+          |    greatest("abc", "aaaa"),
+          |    greatest(true, false),
+          |    greatest(
+          |        cast("2015-01-01" as date),
+          |        cast("2015-07-01" as date)
+          |    ),
+          |    greatest(
+          |        ${dateTimeStampMin},
+          |        ${dateTimeStampMax}
+          |    )
+          |from
+          |    t1
+        """.stripMargin
+
+      val df = sql(functions)
+      checkAnswer(
+        df,
+        Seq(
+          Row(
+            "c",
+            2,
+            2,
+            "ccc",
+            null,
+            "c",
+            2.5,
+            2,
+            2.5f,
+            longMax,
+            2,
+            2,
+            "abc",
+            true,
+            date,
+            dateTimeStampMax)))
+
+    }
+  }
+
+  test("test function FindInSet") {
+    withTable("t1") {
+      sql(
+        "create table t1_find_in_set using parquet as select 'ab' as a, 'b' as 
b, '' as c, 'def' as d")
+
+      val functions =
+        """
+          |select
+          |   find_in_set(a, 'ab'),
+          |   find_in_set(b, 'a,b'),
+          |   find_in_set(a, 'abc,b,ab,c,def'),
+          |   find_in_set(a, 'ab,abc,b,ab,c,def'),
+          |   find_in_set(a, ',,,ab,abc,b,ab,c,def'),
+          |   find_in_set(c, ',ab,abc,b,ab,c,def'),
+          |   find_in_set(a, '数据砖头,abc,b,ab,c,def'),
+          |   find_in_set(d, '数据砖头,abc,b,ab,c,def'),
+          |   find_in_set(d, null)
+          |from t1_find_in_set
+        """.stripMargin
+
+      val df = sql(functions)
+      df.show()
+      checkAnswer(df, Seq(Row(1, 2, 3, 1, 4, 1, 4, 6, null)))
+    }
+  }
+
+  test("test function IsNaN") {
+    withTable("t1") {
+      sql(
+        "create table test_is_nan using parquet as select cast('NaN' as 
double) as c1, cast('NaN' as float) as c2, log(-3) as c3, cast(null as double) 
as c4, 5.5f as c5")
+      val functions =
+        """
+          |select
+          |    isnan(c1),
+          |    isnan(c2),
+          |    isnan(c3),
+          |    isnan(c4),
+          |    isnan(c5)
+          |from
+          |    test_is_nan
+        """.stripMargin
+
+      val df = sql(functions)
+      df.show()
+      checkAnswer(df, Seq(Row(true, true, false, false, false)))
+    }
+  }
+
+  test("test function nvl2") {
+    withTable("t1") {
+      sql(
+        "create table t1 using parquet as select 'base'" +
+          " as base, 3 as exponent")
+      val functions =
+        """
+          |select
+          |  nvl2(null, base, exponent), nvl2(4, base, exponent)
+          |from t1
+                      """.stripMargin
+
+      val df = sql(functions)
+      checkAnswer(df, Seq(Row("base", 3)))
+    }
+  }
+
+  test("test function nvl") {
+    withTable("t1") {
+      sql(
+        "create table t1 using parquet as select 'base'" +
+          " as base, 3 as exponent")
+      val functions =
+        """
+          |select
+          |  nvl(null, base), base, nvl(4, exponent)
+          |from t1
+                    """.stripMargin
+
+      val df = sql(functions)
+      checkAnswer(df, Seq(Row("base", "base", 4)))
+    }
+  }
+
   test("test function Levenshtein") {
     withTable("t1") {
       sql(
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index e8447434..5ecb41bc 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -794,6 +794,8 @@ object NativeConverters extends Logging {
         buildScalarFunction(pb.ScalarFunction.Log2, 
e.children.map(nullIfNegative), e.dataType)
       case e: Log10 =>
         buildScalarFunction(pb.ScalarFunction.Log10, 
e.children.map(nullIfNegative), e.dataType)
+      case e: Nvl2 =>
+        buildScalarFunction(pb.ScalarFunction.Nvl2, e.children, e.dataType)
       case e: Floor if !e.dataType.isInstanceOf[DecimalType] =>
         if (e.child.dataType.isInstanceOf[LongType]) {
           convertExprWithFallback(e.child, isPruningExpr, fallback)
@@ -821,10 +823,12 @@ object NativeConverters extends Logging {
           }
         }
       case e: Expm1 => buildScalarFunction(pb.ScalarFunction.Expm1, 
e.children, e.dataType)
+      case e: Least => buildScalarFunction(pb.ScalarFunction.Least, 
e.children, e.dataType)
       case e: Factorial =>
         buildScalarFunction(pb.ScalarFunction.Factorial, e.children, 
e.dataType)
       case e: Hex => buildScalarFunction(pb.ScalarFunction.Hex, e.children, 
e.dataType)
-
+      case e: IsNaN =>
+        buildScalarFunction(pb.ScalarFunction.IsNaN, e.children, e.dataType)
       case e: Round =>
         e.scale match {
           case Literal(n: Int, _) =>
@@ -834,6 +838,8 @@ object NativeConverters extends Logging {
         }
 
       case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum, 
e.children, e.dataType)
+      case e: FindInSet =>
+        buildScalarFunction(pb.ScalarFunction.FindInSet, e.children, 
e.dataType)
       case e: Abs if e.dataType.isInstanceOf[FloatType] || 
e.dataType.isInstanceOf[DoubleType] =>
         buildScalarFunction(pb.ScalarFunction.Abs, e.children, e.dataType)
       case e: OctetLength =>
@@ -876,8 +882,13 @@ object NativeConverters extends Logging {
         buildExtScalarFunction("Murmur3Hash", children, IntegerType)
       case XxHash64(children, 42L) =>
         buildExtScalarFunction("XxHash64", children, LongType)
+      case e: Greatest =>
+        buildScalarFunction(pb.ScalarFunction.Greatest, e.children, e.dataType)
       case e: Pow =>
         buildScalarFunction(pb.ScalarFunction.Power, e.children, e.dataType)
+      case e: Nvl =>
+        buildScalarFunction(pb.ScalarFunction.Nvl, e.children, e.dataType)
+
       case Year(child) => buildExtScalarFunction("Year", child :: Nil, 
IntegerType)
       case Month(child) => buildExtScalarFunction("Month", child :: Nil, 
IntegerType)
       case DayOfMonth(child) => buildExtScalarFunction("Day", child :: Nil, 
IntegerType)

Reply via email to