This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c52bf44 [AURON#1327] Implement native function of `round` (#1426)
4c52bf44 is described below

commit 4c52bf4423ff5a2e9024a9c3bc2a80e039ff5079
Author: slfan1989 <[email protected]>
AuthorDate: Sat Oct 11 12:01:15 2025 +0800

    [AURON#1327] Implement native function of `round` (#1426)
    
    ### Which issue does this PR close?
    
    Closes #1327.
    
    ### Rationale for this change
    
    `spark_round` is a Rust implementation of an Apache Spark-style round 
function for the DataFusion query engine. Its primary purpose is to perform 
rounding operations on numerical values, adhering to Spark's HALF_UP rounding 
mode (i.e., `0.5` rounds to `1`, `-0.5` rounds to `-1`). It supports multiple 
data types (`Float64`, `Float32`, `Int16`, `Int32`, `Int64`, `Decimal128`) and 
can handle negative scales (e.g., `round(123.4, -1) = 120`) and null values.
    
    
    ### What changes are included in this PR?
    
    - We implemented the Round function following Spark’s HALF_UP rounding 
semantics,
    ensuring full behavioral alignment with Spark SQL.
    
    - For validation, we directly reused the unit tests from 
`MathExpressionsSuite#round/bround`,
    comparing our implementation against Spark’s native results using:
    
    ```
    checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
    checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
    checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
    checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
    checkEvaluation(Round(floatPi, scale), floatResults(i), EmptyRow)
    ```
    
    - We also added additional boundary test cases to ensure that spark_round
    behaves correctly under edge conditions such as large numbers, small 
numbers, and negative scales.
    
    ### Are there any user-facing changes?
    No.
    
    ### How was this patch tested?
    
    Unit Test.
---
 native-engine/datafusion-ext-functions/Cargo.toml  |   2 +-
 native-engine/datafusion-ext-functions/src/lib.rs  |   2 +
 .../datafusion-ext-functions/src/spark_round.rs    | 447 +++++++++++++++++++++
 .../spark/sql/auron/AuronFunctionSuite.scala       | 154 +++++++
 .../apache/spark/sql/auron/NativeConverters.scala  |  10 +-
 5 files changed, 611 insertions(+), 4 deletions(-)

diff --git a/native-engine/datafusion-ext-functions/Cargo.toml 
b/native-engine/datafusion-ext-functions/Cargo.toml
index b172b09c..495e4c7a 100644
--- a/native-engine/datafusion-ext-functions/Cargo.toml
+++ b/native-engine/datafusion-ext-functions/Cargo.toml
@@ -33,4 +33,4 @@ log = { workspace = true }
 num = { workspace = true }
 paste = { workspace = true }
 serde_json = { workspace = true }
-sonic-rs = { workspace = true }
+sonic-rs = { workspace = true }
\ No newline at end of file
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs 
b/native-engine/datafusion-ext-functions/src/lib.rs
index 5f311823..f2311f37 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -27,6 +27,7 @@ mod spark_make_array;
 mod spark_make_decimal;
 mod spark_normalize_nan_and_zero;
 mod spark_null_if;
+mod spark_round;
 mod spark_sha2;
 mod spark_strings;
 mod spark_unscaled_value;
@@ -60,6 +61,7 @@ pub fn create_spark_ext_function(name: &str) -> 
Result<ScalarFunctionImplementat
         "Month" => Arc::new(spark_dates::spark_month),
         "Day" => Arc::new(spark_dates::spark_day),
         "BrickhouseArrayUnion" => 
Arc::new(brickhouse::array_union::array_union),
+        "Round" => Arc::new(spark_round::spark_round),
         "NormalizeNanAndZero" => {
             
Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
         }
diff --git a/native-engine/datafusion-ext-functions/src/spark_round.rs 
b/native-engine/datafusion-ext-functions/src/spark_round.rs
new file mode 100644
index 00000000..e8de0e57
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_round.rs
@@ -0,0 +1,447 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{Decimal128Array, Float32Array, Float64Array, Int16Array, 
Int32Array, Int64Array},
+    datatypes::DataType,
+};
+use datafusion::{
+    common::{
+        DataFusionError, Result, ScalarValue,
+        cast::{
+            as_decimal128_array, as_float32_array, as_float64_array, 
as_int16_array,
+            as_int32_array, as_int64_array,
+        },
+    },
+    physical_plan::ColumnarValue,
+};
+
+/// Spark-style `round(expr, scale)` implementation.
+///
+/// - Uses HALF_UP rounding mode (`0.5 → 1`, `-0.5 → -1`)
+/// - Supports negative scales (e.g., `round(123.4, -1) = 120`)
+/// - Handles Float32/Float64, Decimal128 and Int16/32/64, as array or scalar
+/// - Null-safe: null array slots and null scalars are passed through
+///
+/// `args[0]` is the value to round and `args[1]` the scale, which must be a
+/// non-null `Int32`/`Int64` scalar. An `Int64` scale outside the `i32` range
+/// is saturated rather than truncated bit-wise.
+///
+/// Decimal inputs keep their precision and scale: only the unscaled value is
+/// rounded, so a requested scale that is not smaller than the input scale
+/// leaves the value untouched.
+///
+/// NOTE(review): floats are rounded in binary floating point. The Float32
+/// array path computes in `f32` while the Float32 scalar path widens to
+/// `f64`, so the two can disagree (`3.1415f32` at scale 3 yields `3.142` vs
+/// `3.141`). Confirm which one matches Spark before unifying them.
+///
+/// # Errors
+/// Returns `DataFusionError::Execution` when the argument count is not two,
+/// when the scale is not an integer scalar, or when a scalar value has an
+/// unsupported type. Arrays of any other type fail in the `Float64` downcast.
+pub fn spark_round(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Execution(
+            "spark_round() requires two arguments".to_string(),
+        ));
+    }
+
+    let value = &args[0];
+
+    // Parse scale (must be a literal integer)
+    let scale = match &args[1] {
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(n))) => *n,
+        // Saturate: a plain `as i32` would turn e.g. 2^32 into scale 0.
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(n))) => {
+            (*n).clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32
+        }
+        _ => {
+            return Err(DataFusionError::Execution(
+                "spark_round() scale must be a literal integer".to_string(),
+            ));
+        }
+    };
+
+    match value {
+        // ---------- Array input ----------
+        ColumnarValue::Array(arr) => match arr.data_type() {
+            DataType::Decimal128(..) => {
+                let dec_arr = as_decimal128_array(arr)?;
+                let precision = dec_arr.precision();
+                let in_scale = dec_arr.scale();
+
+                let result = Decimal128Array::from_iter(
+                    dec_arr
+                        .iter()
+                        .map(|opt| opt.map(|v| round_decimal_half_up(v, in_scale, scale))),
+                )
+                .with_precision_and_scale(precision, in_scale)
+                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            // Integer arrays: round in i128, then narrow back. The narrowing
+            // `as` cast wraps if the rounded value leaves the type's range.
+            DataType::Int64 => Ok(ColumnarValue::Array(Arc::new(Int64Array::from_iter(
+                as_int64_array(arr)?
+                    .iter()
+                    .map(|opt| opt.map(|v| round_integer_half_up(i128::from(v), scale) as i64)),
+            )))),
+
+            DataType::Int32 => Ok(ColumnarValue::Array(Arc::new(Int32Array::from_iter(
+                as_int32_array(arr)?
+                    .iter()
+                    .map(|opt| opt.map(|v| round_integer_half_up(i128::from(v), scale) as i32)),
+            )))),
+
+            DataType::Int16 => Ok(ColumnarValue::Array(Arc::new(Int16Array::from_iter(
+                as_int16_array(arr)?
+                    .iter()
+                    .map(|opt| opt.map(|v| round_integer_half_up(i128::from(v), scale) as i16)),
+            )))),
+
+            DataType::Float32 => {
+                let arr = as_float32_array(arr)?;
+                // Loop-invariant: compute the power of ten once per batch.
+                let factor = 10_f32.powi(scale);
+                let result = Float32Array::from_iter(
+                    arr.iter()
+                        .map(|opt| opt.map(|v| round_f32_with_factor(v, factor))),
+                );
+
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            // Float64 fallback: any other array type errors out of the downcast.
+            _ => {
+                let arr = as_float64_array(arr)?;
+                let factor = 10_f64.powi(scale);
+                let result = Float64Array::from_iter(
+                    arr.iter()
+                        .map(|opt| opt.map(|v| round_f64_with_factor(v, factor))),
+                );
+
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+        },
+
+        // ---------- Scalar input ----------
+        ColumnarValue::Scalar(sv) => {
+            if sv.is_null() {
+                return Ok(ColumnarValue::Scalar(sv.clone()));
+            }
+
+            let rounded = match sv {
+                ScalarValue::Float64(Some(v)) => {
+                    ScalarValue::Float64(Some(round_f64_with_factor(*v, 10_f64.powi(scale))))
+                }
+                // Scalar Float32 is widened to f64 for the arithmetic.
+                ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(
+                    round_f64_with_factor(f64::from(*v), 10_f64.powi(scale)) as f32,
+                )),
+                ScalarValue::Int64(Some(v)) => {
+                    ScalarValue::Int64(Some(round_integer_half_up(i128::from(*v), scale) as i64))
+                }
+                ScalarValue::Int32(Some(v)) => {
+                    ScalarValue::Int32(Some(round_integer_half_up(i128::from(*v), scale) as i32))
+                }
+                ScalarValue::Int16(Some(v)) => {
+                    ScalarValue::Int16(Some(round_integer_half_up(i128::from(*v), scale) as i16))
+                }
+                // The value's own scale `s` decides how many digits are dropped,
+                // exactly as in the array branch.
+                ScalarValue::Decimal128(Some(v), p, s) => {
+                    ScalarValue::Decimal128(Some(round_decimal_half_up(*v, *s, scale)), *p, *s)
+                }
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "Unsupported type for spark_round()".to_string(),
+                    ));
+                }
+            };
+
+            Ok(ColumnarValue::Scalar(rounded))
+        }
+    }
+}
+
+/// Rounds an integer at `10^(-scale)`, treating scales below `-38` as "rounds
+/// to zero": `10^39` exceeds `i128::MAX`, hence is more than twice any `i128`.
+fn round_integer_half_up(value: i128, scale: i32) -> i128 {
+    if scale < -38 {
+        0
+    } else {
+        round_i128_half_up(value, scale)
+    }
+}
+
+/// Rounds the unscaled value of a decimal stored with `in_scale` fractional
+/// digits to `scale` fractional digits, keeping the `in_scale` representation.
+fn round_decimal_half_up(unscaled: i128, in_scale: i8, scale: i32) -> i128 {
+    // Widen to i64 so the subtraction cannot overflow for extreme scales.
+    let digits_to_drop = i64::from(in_scale) - i64::from(scale);
+    if digits_to_drop <= 0 {
+        // Nothing to drop: the stored value already has `scale` digits or fewer.
+        return unscaled;
+    }
+    // Dropping more than 39 digits behaves like dropping 39 (the result is 0),
+    // so the clamp is lossless and keeps the cast in range.
+    round_integer_half_up(unscaled, -(digits_to_drop.min(39) as i32))
+}
+
+/// Rounds `v` HALF_UP to the decimal position described by `factor`
+/// (`factor == 10^scale`), guarding the cases where the scaling degenerates.
+fn round_f64_with_factor(v: f64, factor: f64) -> f64 {
+    if !v.is_finite() {
+        // NaN and ±Inf are returned unchanged.
+        return v;
+    }
+    if factor == 0.0 {
+        // 10^scale underflowed: the rounding unit dwarfs every finite value.
+        return 0.0;
+    }
+    let shifted = v * factor;
+    if !shifted.is_finite() {
+        // The scaling overflowed (or `factor` is Inf): `v` has no digits at
+        // that position, so it is already rounded.
+        return v;
+    }
+    round_half_up_f64(shifted) / factor
+}
+
+/// `f32` counterpart of [`round_f64_with_factor`]; all arithmetic stays in `f32`.
+fn round_f32_with_factor(v: f32, factor: f32) -> f32 {
+    if !v.is_finite() {
+        return v;
+    }
+    if factor == 0.0 {
+        return 0.0;
+    }
+    let shifted = v * factor;
+    if !shifted.is_finite() {
+        return v;
+    }
+    round_half_up_f32(shifted) / factor
+}
+
+/// Rounds `x` to the nearest integer with Spark-style HALF_UP semantics: ties
+/// go away from zero (`0.5 → 1`, `-0.5 → -1`). NaN and ±Inf are returned
+/// unchanged.
+///
+/// Delegates to [`f64::round`], which has exactly these tie semantics. The
+/// hand-rolled `(x + 0.5).floor()` form is avoided because the addition itself
+/// rounds: `0.49999999999999994 + 0.5` evaluates to `1.0`, and for
+/// `|x| >= 2^52` adding `0.5` can bump an already-integral value to its
+/// neighbour.
+fn round_half_up_f64(x: f64) -> f64 {
+    x.round()
+}
+
+/// Rounds `x` to the nearest integer with Spark-style HALF_UP semantics for
+/// Float32: ties go away from zero (`0.5 → 1`, `-0.5 → -1`). NaN and ±Inf are
+/// returned unchanged.
+///
+/// Delegates to [`f32::round`] instead of `(x + 0.5).floor()`: the addition
+/// rounds in `f32`, so the largest float below `0.5` would become `1.0` and
+/// integral values at or above `2^23` could move to a neighbouring integer.
+fn round_half_up_f32(x: f32) -> f32 {
+    x.round()
+}
+
+/// Integer rounding using Spark's HALF_UP logic without float precision loss.
+///
+/// Rounds `value` to a multiple of `10^(-scale)`, with ties going away from
+/// zero. A non-negative `scale` returns `value` unchanged, since an integer
+/// has no fractional digits to drop.
+///
+/// Overflow handling:
+/// - when `10^(-scale)` does not fit in an `i128` (scale below `-38`) the
+///   rounding unit is more than twice any `i128`, so the result is `0`;
+/// - the tie test is written as `|r| >= factor - |r|` so it cannot overflow
+///   the way `2 * |r|` can;
+/// - stepping to the next multiple saturates at `i128::MAX` / `i128::MIN`.
+fn round_i128_half_up(value: i128, scale: i32) -> i128 {
+    if scale >= 0 {
+        return value;
+    }
+    // `unsigned_abs` also covers `i32::MIN`, where `-scale` would overflow.
+    let factor = match 10_i128.checked_pow(scale.unsigned_abs()) {
+        Some(factor) => factor,
+        None => return 0,
+    };
+    let remainder = value % factor;
+    let base = value - remainder;
+
+    // `|remainder| < factor`, so neither `abs()` nor the subtraction overflows.
+    let magnitude = remainder.abs();
+    if magnitude < factor - magnitude {
+        base
+    } else if value >= 0 {
+        base.saturating_add(factor)
+    } else {
+        base.saturating_sub(factor)
+    }
+}
+
+// Unit tests for `spark_round()`. The `*_pi_scales` tests round one value at
+// every scale in -6..=6 and compare against a 13-entry expectation table
+// indexed by position (index 0 is scale -6, index 12 is scale 6).
+#[cfg(test)]
+mod tests {
+    use datafusion::{
+        common::{Result, ScalarValue, cast::*},
+        physical_plan::ColumnarValue,
+    };
+
+    use super::*;
+
+    /// Rounds a Decimal128(10, 2) array to scale 1. The output keeps scale 2,
+    /// so 123.45 is expected as unscaled 12350 and -678.95 as -67900.
+    #[test]
+    fn test_round_decimal() -> Result<()> {
+        let arr = Arc::new(
+            Decimal128Array::from_iter_values([12345_i128, -67895_i128])
+                .with_precision_and_scale(10, 2)?,
+        );
+
+        let result = spark_round(&[
+            ColumnarValue::Array(arr.clone()),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
+        ])?;
+
+        assert!(matches!(result, ColumnarValue::Array(_)));
+
+        let out = result.into_array(2)?;
+        let arr = as_decimal128_array(&out)?;
+        let values: Vec<_> = arr.iter().collect();
+        assert_eq!(values, vec![Some(12350_i128), Some(-67900_i128)]);
+
+        Ok(())
+    }
+
+    /// Rounds a Float64 array with a negative scale (-1), i.e. to the nearest
+    /// multiple of ten.
+    #[test]
+    fn test_round_negative_scale() -> Result<()> {
+        let arr = Arc::new(Float64Array::from(vec![Some(123.45), 
Some(-678.9)]));
+        let result = spark_round(&[
+            ColumnarValue::Array(arr),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))),
+        ])?;
+
+        let out = result.into_array(2)?;
+        let out = as_float64_array(&out)?;
+        let v: Vec<_> = out.iter().collect();
+
+        assert_eq!(v, vec![Some(120.0), Some(-680.0)]);
+        Ok(())
+    }
+
+    /// Rounds a Float64 array to a positive scale (two decimal places) and
+    /// checks that a null slot stays null.
+    #[test]
+    fn test_round_float() -> Result<()> {
+        let arr = Arc::new(Float64Array::from(vec![
+            Some(1.2345),
+            Some(-2.3456),
+            Some(0.5),
+            Some(-0.5),
+            None,
+        ]));
+
+        let result = spark_round(&[
+            ColumnarValue::Array(arr),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
+        ])?;
+
+        let out = result.into_array(5)?;
+        let out = as_float64_array(&out)?;
+        let v: Vec<_> = out.iter().collect();
+
+        assert_eq!(
+            v,
+            vec![Some(1.23), Some(-2.35), Some(0.5), Some(-0.5), None]
+        );
+        Ok(())
+    }
+
+    /// Checks half-away-from-zero rounding on a scalar Float64: -1.5 at
+    /// scale 0 must become -2.0.
+    #[test]
+    fn test_round_scalar() -> Result<()> {
+        let s = ColumnarValue::Scalar(ScalarValue::Float64(Some(-1.5)));
+        let result = spark_round(&[s, 
ColumnarValue::Scalar(ScalarValue::Int32(Some(0)))])?;
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => 
assert_eq!(v, -2.0),
+            _ => panic!("wrong result"),
+        }
+        Ok(())
+    }
+
+    /// Rounds the scalar Int16 value 31415 at every scale in -6..=6.
+    #[test]
+    fn test_spark_round_short_pi_scales() -> Result<()> {
+        let short_pi: i16 = 31415;
+        let expected: Vec<i16> = vec![
+            0, 0, 30000, 31000, 31400, 31420, 31415, 31415, 31415, 31415, 
31415, 31415, 31415,
+        ];
+
+        for (i, scale) in (-6..=6).enumerate() {
+            let result = spark_round(&[
+                ColumnarValue::Scalar(ScalarValue::Int16(Some(short_pi))),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
+            ])?;
+
+            let arr = result.into_array(1)?;
+            let out = as_int16_array(&arr)?;
+            assert_eq!(out.value(0), expected[i]);
+        }
+        Ok(())
+    }
+
+    /// Rounds the scalar Float32 value 3.1415 at every scale in -6..=6,
+    /// comparing with an absolute tolerance of 1e-6.
+    #[test]
+    fn test_spark_round_float_pi_scales() -> Result<()> {
+        let float_pi = 3.1415_f32;
+        let expected = vec![
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.141, 3.1415, 
3.1415, 3.1415,
+        ];
+
+        for (i, scale) in (-6..=6).enumerate() {
+            let result = spark_round(&[
+                ColumnarValue::Scalar(ScalarValue::Float32(Some(float_pi))),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
+            ])?;
+
+            let arr = result.into_array(1)?;
+            let out = as_float32_array(&arr)?;
+            assert!(
+                (out.value(0) - expected[i]).abs() < 1e-6,
+                "Mismatch at scale {scale}: expected {}, got {}",
+                expected[i],
+                out.value(0)
+            );
+        }
+        Ok(())
+    }
+
+    /// Rounds the scalar Float64 value PI at every scale in -6..=6,
+    /// comparing with an absolute tolerance of 1e-9.
+    #[test]
+    fn test_spark_round_double_pi_scales() -> Result<()> {
+        let double_pi = std::f64::consts::PI;
+        let expected = vec![
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, 3.1416, 
3.14159, 3.141593,
+        ];
+
+        for (i, scale) in (-6..=6).enumerate() {
+            let result = spark_round(&[
+                ColumnarValue::Scalar(ScalarValue::Float64(Some(double_pi))),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
+            ])?;
+
+            let arr = result.into_array(1)?;
+            let out = as_float64_array(&arr)?;
+            let actual = out.value(0);
+            assert!(
+                (actual - expected[i]).abs() < 1e-9,
+                "Mismatch at scale {scale}: expected {}, got {}",
+                expected[i],
+                actual
+            );
+        }
+        Ok(())
+    }
+
+    /// Rounds the scalar Int32 value 314159265 at every scale in -6..=6.
+    #[test]
+    fn test_spark_round_int_pi_scales() -> Result<()> {
+        let int_pi = 314159265_i32;
+        let expected = vec![
+            314000000, 314200000, 314160000, 314159000, 314159300, 314159270, 
314159265, 314159265,
+            314159265, 314159265, 314159265, 314159265, 314159265,
+        ];
+
+        for (i, scale) in (-6..=6).enumerate() {
+            let result = spark_round(&[
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(int_pi))),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
+            ])?;
+
+            let arr = result.into_array(1)?;
+            let out = as_int32_array(&arr)?;
+            assert_eq!(
+                out.value(0),
+                expected[i],
+                "Mismatch at scale {scale}: expected {}, got {}",
+                expected[i],
+                out.value(0)
+            );
+        }
+        Ok(())
+    }
+
+    /// Rounds a long-sized integer at every scale in -6..=6. The value is
+    /// passed as a scalar Decimal128(38, 0), so this exercises the scalar
+    /// decimal path with input scale 0, not the Int64 path.
+    #[test]
+    fn test_spark_round_long_pi_scales() -> Result<()> {
+        let long_pi = 31415926535897932_i128;
+        let expected = vec![
+            31415926536000000,
+            31415926535900000,
+            31415926535900000,
+            31415926535898000,
+            31415926535897900,
+            31415926535897930,
+            31415926535897932,
+            31415926535897932,
+            31415926535897932,
+            31415926535897932,
+            31415926535897932,
+            31415926535897932,
+            31415926535897932,
+        ];
+
+        for (i, scale) in (-6..=6).enumerate() {
+            let result = spark_round(&[
+                ColumnarValue::Scalar(ScalarValue::Decimal128(Some(long_pi), 
38, 0)),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
+            ])?;
+
+            let arr = result.into_array(1)?;
+            let out = as_decimal128_array(&arr)?;
+            assert_eq!(
+                out.value(0),
+                expected[i],
+                "Mismatch at scale {scale}: expected {}, got {}",
+                expected[i],
+                out.value(0)
+            );
+        }
+        Ok(())
+    }
+}
diff --git 
a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
 
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
index 2f7e5707..fb46cd08 100644
--- 
a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
+++ 
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -120,4 +120,158 @@ class AuronFunctionSuite
       checkAnswer(df, Seq(Row("uron Spark SQ")))
     }
   }
+
+  test("round function with varying scales for intPi") {
+    // One INT row rounded at every scale from -6 to 6; scales >= 0 must
+    // return the integer unchanged.
+    withTable("t2") {
+      sql("CREATE TABLE t2 (c1 INT) USING parquet")
+
+      val intPi: Int = 314159265
+      sql(s"INSERT INTO t2 VALUES($intPi)")
+
+      // (scale, expected) pairs, checked in ascending scale order.
+      val expectedByScale = Seq(
+        -6 -> 314000000,
+        -5 -> 314200000,
+        -4 -> 314160000,
+        -3 -> 314159000,
+        -2 -> 314159300,
+        -1 -> 314159270,
+        0 -> 314159265,
+        1 -> 314159265,
+        2 -> 314159265,
+        3 -> 314159265,
+        4 -> 314159265,
+        5 -> 314159265,
+        6 -> 314159265)
+
+      for ((scale, expected) <- expectedByScale) {
+        val rounded = sql(s"SELECT round(c1, $scale) AS xx FROM t2")
+        checkAnswer(rounded, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("round function with varying scales for doublePi") {
+    // One DOUBLE row (math.Pi) rounded at every scale from -6 to 6.
+    withTable("t1") {
+      sql("create table t1(c1 double) using parquet")
+
+      val doublePi: Double = math.Pi
+      sql(s"insert into t1 values($doublePi)")
+
+      // (scale, expected) pairs, checked in ascending scale order.
+      val expectedByScale = Seq(
+        -6 -> 0.0,
+        -5 -> 0.0,
+        -4 -> 0.0,
+        -3 -> 0.0,
+        -2 -> 0.0,
+        -1 -> 0.0,
+        0 -> 3.0,
+        1 -> 3.1,
+        2 -> 3.14,
+        3 -> 3.142,
+        4 -> 3.1416,
+        5 -> 3.14159,
+        6 -> 3.141593)
+
+      for ((scale, expected) <- expectedByScale) {
+        val rounded = sql(s"select round(c1, $scale) from t1")
+        checkAnswer(rounded, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("round function with varying scales for floatPi") {
+    // One FLOAT row (3.1415f) rounded at every scale from -6 to 6.
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c1 FLOAT) USING parquet")
+
+      val floatPi: Float = 3.1415f
+      sql(s"INSERT INTO t1 VALUES($floatPi)")
+
+      // (scale, expected) pairs, checked in ascending scale order.
+      val expectedByScale = Seq(
+        -6 -> 0.0f,
+        -5 -> 0.0f,
+        -4 -> 0.0f,
+        -3 -> 0.0f,
+        -2 -> 0.0f,
+        -1 -> 0.0f,
+        0 -> 3.0f,
+        1 -> 3.1f,
+        2 -> 3.14f,
+        3 -> 3.142f,
+        4 -> 3.1415f,
+        5 -> 3.1415f,
+        6 -> 3.1415f)
+
+      for ((scale, expected) <- expectedByScale) {
+        val rounded = sql(s"select round(c1, $scale) from t1")
+        checkAnswer(rounded, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("round function with varying scales for shortPi") {
+    // One SMALLINT row rounded at every scale from -6 to 6; expectations
+    // are typed as Short so they compare equal to the SMALLINT result.
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c1 SMALLINT) USING parquet")
+
+      val shortPi: Short = 31415
+      sql(s"INSERT INTO t1 VALUES($shortPi)")
+
+      // (scale, expected) pairs, checked in ascending scale order.
+      val expectedByScale = Seq(
+        -6 -> 0.toShort,
+        -5 -> 0.toShort,
+        -4 -> 30000.toShort,
+        -3 -> 31000.toShort,
+        -2 -> 31400.toShort,
+        -1 -> 31420.toShort,
+        0 -> 31415.toShort,
+        1 -> 31415.toShort,
+        2 -> 31415.toShort,
+        3 -> 31415.toShort,
+        4 -> 31415.toShort,
+        5 -> 31415.toShort,
+        6 -> 31415.toShort)
+
+      for ((scale, expected) <- expectedByScale) {
+        val rounded = sql(s"SELECT round(c1, $scale) FROM t1")
+        checkAnswer(rounded, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("round function with varying scales for longPi") {
+    // One BIGINT row rounded at every scale from -6 to 6; scales >= 0 must
+    // return the value unchanged.
+    withTable("t1") {
+      sql("CREATE TABLE t1 (c1 BIGINT) USING parquet")
+
+      val longPi: Long = 31415926535897932L
+      sql(s"INSERT INTO t1 VALUES($longPi)")
+
+      // (scale, expected) pairs, checked in ascending scale order.
+      val expectedByScale = Seq(
+        -6 -> 31415926536000000L,
+        -5 -> 31415926535900000L,
+        -4 -> 31415926535900000L,
+        -3 -> 31415926535898000L,
+        -2 -> 31415926535897900L,
+        -1 -> 31415926535897930L,
+        0 -> 31415926535897932L,
+        1 -> 31415926535897932L,
+        2 -> 31415926535897932L,
+        3 -> 31415926535897932L,
+        4 -> 31415926535897932L,
+        5 -> 31415926535897932L,
+        6 -> 31415926535897932L)
+
+      for ((scale, expected) <- expectedByScale) {
+        val rounded = sql(s"SELECT round(c1, $scale) FROM t1")
+        checkAnswer(rounded, Seq(Row(expected)))
+      }
+    }
+  }
 }
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 3ec49807..72f96bf0 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -815,9 +815,13 @@ object NativeConverters extends Logging {
         buildScalarFunction(pb.ScalarFunction.Factorial, e.children, 
e.dataType)
       case e: Hex => buildScalarFunction(pb.ScalarFunction.Hex, e.children, 
e.dataType)
 
-      // TODO: datafusion's round() has different behavior from spark
-      // case e @ Round(_1, Literal(n: Int, _)) if 
_1.dataType.isInstanceOf[FractionalType] =>
-      //   buildScalarFunction(pb.ScalarFunction.Round, Seq(_1, 
Literal(n.toLong)), e.dataType)
+      case e: Round =>
+        e.scale match {
+          case Literal(n: Int, _) =>
+            buildExtScalarFunction("Round", Seq(e.child, Literal(n.toLong)), 
e.dataType)
+          case _ =>
+            buildExtScalarFunction("Round", Seq(e.child, Literal(0L)), 
e.dataType)
+        }
 
       case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum, 
e.children, e.dataType)
       case e: Abs if e.dataType.isInstanceOf[FloatType] || 
e.dataType.isInstanceOf[DoubleType] =>

Reply via email to