This is an automated email from the ASF dual-hosted git repository.

richox pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d033236 [AURON #1502] Implement native function of bround. (#1706)
9d033236 is described below

commit 9d033236d0ca3565ca57a269b43fa76bfc1d6f3f
Author: slfan1989 <[email protected]>
AuthorDate: Sat Jan 17 14:32:19 2026 +0800

    [AURON #1502] Implement native function of bround. (#1706)
    
    <!--
    Thanks for sending a pull request! Please keep the following tips in
    mind:
    - Start the PR title with the related issue ID, e.g. '[AURON #XXXX]
    Short summary...'.
    - Make your PR title clear and descriptive, summarizing what this PR
    changes.
      - Provide a concise example to reproduce the issue, if possible.
      - Keep the PR description up to date with all changes.
    -->
    
    ### Which issue does this PR close?
    
    Closes #1502.
    
    ### Rationale for this change
    
    To achieve full compatibility with Spark’s numeric functions, we should
    implement bround() with the following characteristics:
    
    > Expected behavior
    
    Function name: `bround(expr, scale)`
    Rounding mode: `HALF_EVEN (bankers’ rounding)`
    Example:
    - bround(2.5) → 2.0
    - bround(3.5) → 4.0
    - bround(-2.5) → -2.0
    
    Supports: `FLOAT`, `DOUBLE`, `DECIMAL`, `INT16/32/64`
    
    - Handles negative scales: e.g., `bround(123.4, -1)` = 120
    - Null-safe: should return NULL if input is NULL
    - Array and scalar inputs: consistent with current round()
    implementation
    
    
    ### What changes are included in this PR?
    
    This PR adds full support for the bround() function, which performs
    bankers’ rounding (HALF_EVEN).
    The following changes are included:
    
    - Added native implementation of spark_bround() in the expression layer.
    - Added BRound expression support in NativeConverters for proper Spark →
    native translation.
    - Added comprehensive unit tests to verify correctness for:
       - Positive and negative numbers
       - Different scale values (including negative scales)
       - Various numeric types (FLOAT, DOUBLE, DECIMAL, INT, LONG)
       - Null input handling
    
    ### Are there any user-facing changes?
    
    No.
    
    ### How was this patch tested?
    
    CI.
    
    ---------
    
    Signed-off-by: slfan1989 <[email protected]>
---
 native-engine/datafusion-ext-functions/src/lib.rs  |   2 +
 .../datafusion-ext-functions/src/spark_bround.rs   | 515 +++++++++++++++++++++
 .../org/apache/auron/AuronFunctionSuite.scala      | 163 +++++++
 .../scala/org/apache/auron/AuronQuerySuite.scala   |   2 +-
 .../apache/spark/sql/auron/NativeConverters.scala  |   8 +
 5 files changed, 689 insertions(+), 1 deletion(-)

diff --git a/native-engine/datafusion-ext-functions/src/lib.rs 
b/native-engine/datafusion-ext-functions/src/lib.rs
index 0117359e..db297f29 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -19,6 +19,7 @@ use datafusion::{common::Result, 
logical_expr::ScalarFunctionImplementation};
 use datafusion_ext_commons::df_unimplemented_err;
 
 mod brickhouse;
+mod spark_bround;
 mod spark_check_overflow;
 mod spark_crypto;
 mod spark_dates;
@@ -78,6 +79,7 @@ pub fn create_auron_ext_function(
         "Spark_Second" => Arc::new(spark_dates::spark_second),
         "Spark_BrickhouseArrayUnion" => 
Arc::new(brickhouse::array_union::array_union),
         "Spark_Round" => Arc::new(spark_round::spark_round),
+        "Spark_BRound" => Arc::new(spark_bround::spark_bround),
         "Spark_NormalizeNanAndZero" => {
             
Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
         }
diff --git a/native-engine/datafusion-ext-functions/src/spark_bround.rs 
b/native-engine/datafusion-ext-functions/src/spark_bround.rs
new file mode 100644
index 00000000..4cb0bd99
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_bround.rs
@@ -0,0 +1,515 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{Decimal128Array, Float32Array, Float64Array, Int16Array, 
Int32Array, Int64Array},
+    datatypes::DataType,
+};
+use datafusion::{
+    common::{
+        DataFusionError, Result, ScalarValue,
+        cast::{
+            as_decimal128_array, as_float32_array, as_float64_array, 
as_int16_array,
+            as_int32_array, as_int64_array,
+        },
+    },
+    physical_plan::ColumnarValue,
+};
+
+/// Spark-style `bround(expr, scale)` implementation (HALF_EVEN).
+/// - HALF_EVEN (banker's rounding): ties go to the nearest even
+/// - Supports negative scales (e.g., bround(123.4, -1) = 120)
+/// - Handles Float, Decimal, Int16/32/64
+/// - Null-safe
+pub fn spark_bround(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Execution(
+            "spark_bround() requires two arguments".to_string(),
+        ));
+    }
+
+    let value = &args[0];
+    let scale_val = &args[1];
+
+    // Parse scale (must be a literal integer)
+    let scale = match scale_val {
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(n))) => *n,
+        ColumnarValue::Scalar(ScalarValue::Int64(Some(n))) => *n as i32,
+        _ => {
+            return Err(DataFusionError::Execution(
+                "spark_bround() scale must be a literal integer".to_string(),
+            ));
+        }
+    };
+
+    match value {
+        // ---------- Array input ----------
+        ColumnarValue::Array(arr) => match arr.data_type() {
+            DataType::Decimal128(..) => {
+                let dec_arr = as_decimal128_array(arr)?;
+                let precision = dec_arr.precision();
+                let in_scale = dec_arr.scale();
+
+                let result = 
Decimal128Array::from_iter(dec_arr.iter().map(|opt| {
+                    opt.map(|v| {
+                        let diff = in_scale as i32 - scale;
+                        if diff >= 0 {
+                            // reduce fractional digits by diff using HALF_EVEN
+                            round_i128_half_even(v, -diff)
+                        } else {
+                            // increasing scale (more fractional digits): 
multiply
+                            v * 10_i128.pow((-diff) as u32)
+                        }
+                    })
+                }))
+                .with_precision_and_scale(precision, in_scale)
+                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            DataType::Int64 => 
Ok(ColumnarValue::Array(Arc::new(Int64Array::from_iter(
+                as_int64_array(arr)?
+                    .iter()
+                    .map(|opt| opt.map(|v| round_i128_half_even(v as i128, 
scale) as i64)),
+            )))),
+
+            DataType::Int32 => 
Ok(ColumnarValue::Array(Arc::new(Int32Array::from_iter(
+                as_int32_array(arr)?
+                    .iter()
+                    .map(|opt| opt.map(|v| round_i128_half_even(v as i128, 
scale) as i32)),
+            )))),
+
+            DataType::Int16 => 
Ok(ColumnarValue::Array(Arc::new(Int16Array::from_iter(
+                as_int16_array(arr)?
+                    .iter()
+                    .map(|opt| opt.map(|v| round_i128_half_even(v as i128, 
scale) as i16)),
+            )))),
+
+            DataType::Float32 => {
+                let arr = as_float32_array(arr)?;
+                let factor = 10_f32.powi(scale);
+                let result = Float32Array::from_iter(arr.iter().map(|opt| {
+                    opt.map(|v| {
+                        if v.is_nan() || v.is_infinite() {
+                            v
+                        } else {
+                            round_half_even_f32(v * factor) / factor
+                        }
+                    })
+                }));
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            // Float64 fallback
+            _ => {
+                let arr = as_float64_array(arr)?;
+                let factor = 10_f64.powi(scale);
+                let result = Float64Array::from_iter(arr.iter().map(|opt| {
+                    opt.map(|v| {
+                        if v.is_nan() || v.is_infinite() {
+                            v
+                        } else {
+                            round_half_even_f64(v * factor) / factor
+                        }
+                    })
+                }));
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+        },
+
+        // ---------- Scalar input ----------
+        ColumnarValue::Scalar(sv) => {
+            if sv.is_null() {
+                return Ok(ColumnarValue::Scalar(sv.clone()));
+            }
+
+            Ok(match sv {
+                ScalarValue::Float64(Some(v)) => {
+                    let f = 10_f64.powi(scale);
+                    ColumnarValue::Scalar(ScalarValue::Float64(Some(
+                        round_half_even_f64(v * f) / f,
+                    )))
+                }
+                ScalarValue::Float32(Some(v)) => {
+                    let f = 10_f64.powi(scale);
+                    ColumnarValue::Scalar(ScalarValue::Float32(Some(
+                        (round_half_even_f64((*v as f64) * f) / f) as f32,
+                    )))
+                }
+                ScalarValue::Int64(Some(v)) => 
ColumnarValue::Scalar(ScalarValue::Int64(Some(
+                    round_i128_half_even(*v as i128, scale) as i64,
+                ))),
+                ScalarValue::Int32(Some(v)) => 
ColumnarValue::Scalar(ScalarValue::Int32(Some(
+                    round_i128_half_even(*v as i128, scale) as i32,
+                ))),
+                ScalarValue::Int16(Some(v)) => 
ColumnarValue::Scalar(ScalarValue::Int16(Some(
+                    round_i128_half_even(*v as i128, scale) as i16,
+                ))),
+                ScalarValue::Decimal128(Some(v), p, s) => 
ColumnarValue::Scalar(
+                    ScalarValue::Decimal128(Some(round_i128_half_even(*v, 
scale)), *p, *s),
+                ),
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "Unsupported type for spark_bround()".to_string(),
+                    ));
+                }
+            })
+        }
+    }
+}
+
/// Rounds `x` to the nearest integer, sending exact `.5` ties to the even neighbour (f64).
///
/// NaN and +/-Infinity are returned as-is; the sign of the input (including `-0.0`) is kept.
fn round_half_even_f64(x: f64) -> f64 {
    if !x.is_finite() {
        return x;
    }
    let magnitude = x.abs();
    let whole = magnitude.floor();
    let frac = magnitude - whole;

    let rounded = if frac == 0.5 {
        // Exact tie: keep `whole` when it is even, otherwise step up to the even neighbour.
        if whole % 2.0 == 0.0 { whole } else { whole + 1.0 }
    } else if frac > 0.5 {
        whole + 1.0
    } else {
        whole
    };

    rounded.copysign(x)
}
+
/// Rounds `x` to the nearest integer, sending exact `.5` ties to the even neighbour (f32).
///
/// NaN and +/-Infinity are returned as-is; the sign of the input (including `-0.0`) is kept.
fn round_half_even_f32(x: f32) -> f32 {
    if !x.is_finite() {
        return x;
    }
    let magnitude = x.abs();
    let whole = magnitude.floor();
    let frac = magnitude - whole;

    let rounded = if frac == 0.5 {
        // Exact tie: keep `whole` when it is even, otherwise step up to the even neighbour.
        if whole % 2.0 == 0.0 { whole } else { whole + 1.0 }
    } else if frac > 0.5 {
        whole + 1.0
    } else {
        whole
    };

    rounded.copysign(x)
}
+
/// Rounds an integral `i128` to the nearest multiple of `10^(-scale)` using HALF_EVEN (exact ties
/// go to the even multiple), without any floating-point precision loss.
///
/// - `scale >= 0`: integers have no fractional digits to drop, so `value` is returned unchanged.
/// - `scale < 0`: rounds to tens/hundreds/..., e.g. `(125, -1) -> 120`, `(135, -1) -> 140`.
///
/// When `10^(-scale)` does not fit in `i128` (`-scale > 38`), half the rounding unit already
/// exceeds `i128::MAX`, so every input rounds to `0`.
///
/// NOTE(review): rounding away from zero can still overflow when `|value|` lies within one
/// rounding unit of `i128::MAX`. Decimal128 values with precision <= 38 (`|value| < 10^38`) and
/// widened 64-bit integers cannot reach that range.
fn round_i128_half_even(value: i128, scale: i32) -> i128 {
    if scale >= 0 {
        return value;
    }
    // `unsigned_abs` avoids overflowing `-scale` for i32::MIN; `checked_pow` avoids the
    // overflow of 10^k for k > 38.
    let Some(factor) = 10_i128.checked_pow(scale.unsigned_abs()) else {
        return 0;
    };

    // `%` truncates toward zero, so `remainder` carries the sign of `value` and `base` is `value`
    // truncated toward zero to a multiple of `factor`.
    let remainder = value % factor;
    let base = value - remainder;

    // `factor` is a power of ten >= 10 and therefore even, so `factor / 2` is the exact midpoint.
    // Comparing against it (rather than doubling the remainder) cannot overflow.
    let half = factor / 2;
    let magnitude = remainder.abs();
    let round_away = if magnitude == half {
        // Tie: move away from zero only if that lands on an even multiple of `factor`.
        (base / factor) % 2 != 0
    } else {
        magnitude > half
    };

    if !round_away {
        base
    } else if value >= 0 {
        base + factor
    } else {
        base - factor
    }
}
+
#[cfg(test)]
mod bround_tests {
    use arrow::array::ArrayRef;
    use datafusion::{common::Result, physical_plan::ColumnarValue};

    use super::*;

    /// Scales exercised by the "pi" tests: -6..=6, i.e. 13 values in ascending order.
    fn scales_range() -> impl Iterator<Item = i32> {
        -6..=6
    }

    /// Calls `spark_bround(value, <Int32 scale literal>)` and materializes `num_rows` rows.
    fn run_bround(value: ColumnarValue, scale: i32, num_rows: usize) -> Result<ArrayRef> {
        spark_bround(&[
            value,
            ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
        ])?
        .into_array(num_rows)
    }

    // Float64 array: HALF_EVEN (banker's rounding) at scale 0.
    #[test]
    fn test_bround_float64_ties() -> Result<()> {
        let arr = Arc::new(Float64Array::from(vec![
            Some(1.5),
            Some(2.5),
            Some(-0.5),
            Some(-1.5),
            Some(0.5),
        ]));
        let out = run_bround(ColumnarValue::Array(arr), 0, 5)?;
        let v: Vec<_> = as_float64_array(&out)?.iter().collect();
        // HALF_EVEN: 1.5→2, 2.5→2, -0.5→0, -1.5→-2, 0.5→0
        assert_eq!(
            v,
            vec![Some(2.0), Some(2.0), Some(0.0), Some(-2.0), Some(0.0)]
        );
        Ok(())
    }

    // Float64 array: negative scale rounds to tens, ties to the even ten.
    #[test]
    fn test_bround_negative_scale_float() -> Result<()> {
        let arr = Arc::new(Float64Array::from(vec![
            Some(125.0),
            Some(135.0),
            Some(145.0),
            Some(155.0),
        ]));
        let out = run_bround(ColumnarValue::Array(arr), -1, 4)?;
        let v: Vec<_> = as_float64_array(&out)?.iter().collect();
        assert_eq!(v, vec![Some(120.0), Some(140.0), Some(140.0), Some(160.0)]);
        Ok(())
    }

    // Decimal array: dropping one fractional digit (scale 2 -> 1) keeps the input scale.
    #[test]
    fn test_bround_decimal_array() -> Result<()> {
        let arr = Arc::new(
            Decimal128Array::from_iter_values([12345_i128, 67895_i128])
                .with_precision_and_scale(10, 2)?,
        );
        let out = run_bround(ColumnarValue::Array(arr), 1, 2)?;
        let vals: Vec<_> = as_decimal128_array(&out)?.iter().collect();
        assert_eq!(vals, vec![Some(12340_i128), Some(67900_i128)]);
        Ok(())
    }

    // Decimal array: rounding to exactly the input scale is the identity.
    #[test]
    fn test_bround_decimal_array_scale_eq_in_scale() -> Result<()> {
        let arr = Arc::new(
            Decimal128Array::from_iter_values([12345_i128, -67895_i128])
                .with_precision_and_scale(10, 2)?,
        );
        let out = run_bround(ColumnarValue::Array(arr), 2, 2)?;
        let vals: Vec<_> = as_decimal128_array(&out)?.iter().collect();
        assert_eq!(vals, vec![Some(12345_i128), Some(-67895_i128)]);
        Ok(())
    }

    // Double scalar: π across scales -6..=6.
    #[test]
    fn test_bround_double_pi_scales() -> Result<()> {
        let double_pi = std::f64::consts::PI;
        let expected: [f64; 13] = [
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593,
        ];

        for (scale, want) in scales_range().zip(expected) {
            let out = run_bround(
                ColumnarValue::Scalar(ScalarValue::Float64(Some(double_pi))),
                scale,
                1,
            )?;
            let got = as_float64_array(&out)?.value(0);
            assert!(
                (got - want).abs() < 1e-9,
                "scale={scale}: expected {want}, got {got}"
            );
        }
        Ok(())
    }

    // Float scalar: 3.1415f across scales -6..=6.
    #[test]
    fn test_bround_float_pi_scales() -> Result<()> {
        let float_pi = 3.1415_f32;
        let expected: [f32; 13] = [
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.141, 3.1415, 3.1415, 3.1415,
        ];

        for (scale, want) in scales_range().zip(expected) {
            let out = run_bround(
                ColumnarValue::Scalar(ScalarValue::Float32(Some(float_pi))),
                scale,
                1,
            )?;
            let got = as_float32_array(&out)?.value(0);
            assert!(
                (got - want).abs() < 1e-6,
                "scale={scale}: expected {want}, got {got}"
            );
        }
        Ok(())
    }

    // Short scalar: 31415 across scales -6..=6.
    #[test]
    fn test_bround_short_pi_scales() -> Result<()> {
        let short_pi: i16 = 31415;
        let expected: [i16; 13] = [
            0, 0, 30_000, 31_000, 31_400, 31_420, 31_415, 31_415, 31_415, 31_415, 31_415, 31_415,
            31_415,
        ];

        for (scale, want) in scales_range().zip(expected) {
            let out = run_bround(
                ColumnarValue::Scalar(ScalarValue::Int16(Some(short_pi))),
                scale,
                1,
            )?;
            let got = as_int16_array(&out)?.value(0);
            assert_eq!(got, want, "scale={scale}: expected {want}, got {got}");
        }
        Ok(())
    }

    // Int scalar: 314159265 across scales -6..=6.
    #[test]
    fn test_bround_int_pi_scales() -> Result<()> {
        let int_pi: i32 = 314_159_265;
        let expected: [i32; 13] = [
            314_000_000,
            314_200_000,
            314_160_000,
            314_159_000,
            314_159_300,
            314_159_260,
            314_159_265,
            314_159_265,
            314_159_265,
            314_159_265,
            314_159_265,
            314_159_265,
            314_159_265,
        ];

        for (scale, want) in scales_range().zip(expected) {
            let out = run_bround(
                ColumnarValue::Scalar(ScalarValue::Int32(Some(int_pi))),
                scale,
                1,
            )?;
            let got = as_int32_array(&out)?.value(0);
            assert_eq!(got, want, "scale={scale}: expected {want}, got {got}");
        }
        Ok(())
    }

    // "Long" π passed as a Decimal128(38, 0) scalar across scales -6..=6.
    #[test]
    fn test_bround_long_pi_scales() -> Result<()> {
        let long_pi: i128 = 31_415_926_535_897_932_i128;
        let expected: [i128; 13] = [
            31_415_926_536_000_000,
            31_415_926_535_900_000,
            31_415_926_535_900_000,
            31_415_926_535_898_000,
            31_415_926_535_897_900,
            31_415_926_535_897_930,
            long_pi,
            long_pi,
            long_pi,
            long_pi,
            long_pi,
            long_pi,
            long_pi,
        ];

        for (scale, want) in scales_range().zip(expected) {
            let out = run_bround(
                ColumnarValue::Scalar(ScalarValue::Decimal128(Some(long_pi), 38, 0)),
                scale,
                1,
            )?;
            let got = as_decimal128_array(&out)?.value(0);
            assert_eq!(got, want, "scale={scale}: expected {want}, got {got}");
        }
        Ok(())
    }

    // Double scalars: HALF_EVEN for ties of both signs, at positive, zero and negative scales.
    #[test]
    fn test_bround_half_even_ties_and_signs() -> Result<()> {
        let cases = [
            (2.5, 0, 2.0),
            (3.5, 0, 4.0),
            (-2.5, 0, -2.0),
            (-3.5, 0, -4.0),
            (-0.35, 1, -0.4),
            (-35.0, -1, -40.0),
        ];

        for (input, scale, want) in cases {
            let out = run_bround(
                ColumnarValue::Scalar(ScalarValue::Float64(Some(input))),
                scale,
                1,
            )?;
            let got = as_float64_array(&out)?.value(0);
            assert!(
                (got - want).abs() < 1e-12,
                "scale={scale}: expected {want}, got {got}"
            );
        }
        Ok(())
    }

    // Nulls stay null in arrays (float and integer paths) and as scalars.
    #[test]
    fn test_bround_preserves_nulls() -> Result<()> {
        let floats = Arc::new(Float64Array::from(vec![Some(1.25), None]));
        let out = run_bround(ColumnarValue::Array(floats), 1, 2)?;
        let v: Vec<_> = as_float64_array(&out)?.iter().collect();
        assert_eq!(v, vec![Some(1.2), None]);

        let ints = Arc::new(Int32Array::from(vec![Some(25), None, Some(-35)]));
        let out = run_bround(ColumnarValue::Array(ints), -1, 3)?;
        let v: Vec<_> = as_int32_array(&out)?.iter().collect();
        assert_eq!(v, vec![Some(20), None, Some(-40)]);

        let out = spark_bround(&[
            ColumnarValue::Scalar(ScalarValue::Float64(None)),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
        ])?;
        assert!(matches!(
            out,
            ColumnarValue::Scalar(ScalarValue::Float64(None))
        ));
        Ok(())
    }

    // NaN and +/-Infinity pass through unchanged.
    #[test]
    fn test_bround_nan_and_infinity_pass_through() -> Result<()> {
        let arr = Arc::new(Float64Array::from(vec![
            f64::NAN,
            f64::INFINITY,
            f64::NEG_INFINITY,
        ]));
        let out = run_bround(ColumnarValue::Array(arr), 2, 3)?;
        let a = as_float64_array(&out)?;
        assert!(a.value(0).is_nan());
        assert_eq!(a.value(1), f64::INFINITY);
        assert_eq!(a.value(2), f64::NEG_INFINITY);
        Ok(())
    }

    // Wrong arity and a non-literal scale are rejected.
    #[test]
    fn test_bround_rejects_bad_arguments() {
        let one_arg = [ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))];
        assert!(spark_bround(&one_arg).is_err());

        let column_scale = [
            ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))),
            ColumnarValue::Array(Arc::new(Int32Array::from(vec![1]))),
        ];
        assert!(spark_bround(&column_scale).is_err());
    }
}
diff --git 
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
index 18a798bc..c7e2f6a9 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala
@@ -489,4 +489,167 @@ class AuronFunctionSuite extends AuronQueryTest with 
BaseAuronSQLSuite {
       }
     }
   }
+
+  test("bround function with varying scales for doublePi") {
+    withTable("t1") {
+      val doublePi: Double = math.Pi
+      sql(s"CREATE TABLE t1(c1 DOUBLE) USING parquet")
+      sql(s"INSERT INTO t1 VALUES($doublePi)")
+
+      val scales = -6 to 6
+      val expectedResults = Map(
+        -6 -> 0.0,
+        -5 -> 0.0,
+        -4 -> 0.0,
+        -3 -> 0.0,
+        -2 -> 0.0,
+        -1 -> 0.0,
+        0 -> 3.0,
+        1 -> 3.1,
+        2 -> 3.14,
+        3 -> 3.142,
+        4 -> 3.1416,
+        5 -> 3.14159,
+        6 -> 3.141593)
+
+      scales.foreach { scale =>
+        val df = sql(s"SELECT bround(c1, $scale) FROM t1")
+        val expected = expectedResults(scale)
+        checkAnswer(df, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("bround function with varying scales for floatPi") {
+    withTable("t1") {
+      val floatPi: Float = 3.1415f
+      sql(s"CREATE TABLE t1(c1 FLOAT) USING parquet")
+      sql(s"INSERT INTO t1 VALUES($floatPi)")
+
+      val scales = -6 to 6
+      val expectedResults = Map(
+        -6 -> 0.0f,
+        -5 -> 0.0f,
+        -4 -> 0.0f,
+        -3 -> 0.0f,
+        -2 -> 0.0f,
+        -1 -> 0.0f,
+        0 -> 3.0f,
+        1 -> 3.1f,
+        2 -> 3.14f,
+        3 -> 3.142f,
+        4 -> 3.1415f,
+        5 -> 3.1415f,
+        6 -> 3.1415f)
+
+      scales.foreach { scale =>
+        val df = sql(s"SELECT bround(c1, $scale) FROM t1")
+        val expected = expectedResults(scale)
+        checkAnswer(df, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("bround function with varying scales for shortPi") {
+    withTable("t1") {
+      val shortPi: Short = 31415
+      sql(s"CREATE TABLE t1(c1 SMALLINT) USING parquet")
+      sql(s"INSERT INTO t1 VALUES($shortPi)")
+
+      val scales = -6 to 6
+      val expectedResults = Map(
+        -6 -> 0.toShort,
+        -5 -> 0.toShort,
+        -4 -> 30000.toShort,
+        -3 -> 31000.toShort,
+        -2 -> 31400.toShort,
+        -1 -> 31420.toShort,
+        0 -> 31415.toShort,
+        1 -> 31415.toShort,
+        2 -> 31415.toShort,
+        3 -> 31415.toShort,
+        4 -> 31415.toShort,
+        5 -> 31415.toShort,
+        6 -> 31415.toShort)
+
+      scales.foreach { scale =>
+        val df = sql(s"SELECT bround(c1, $scale) FROM t1")
+        val expected = expectedResults(scale)
+        checkAnswer(df, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("bround function with varying scales for intPi") {
+    withTable("t1") {
+      val intPi: Int = 314159265
+      sql(s"CREATE TABLE t1(c1 INT) USING parquet")
+      sql(s"INSERT INTO t1 VALUES($intPi)")
+
+      val scales = -6 to 6
+      val expectedResults = Map(
+        -6 -> 314000000,
+        -5 -> 314200000,
+        -4 -> 314160000,
+        -3 -> 314159000,
+        -2 -> 314159300,
+        -1 -> 314159260,
+        0 -> 314159265,
+        1 -> 314159265,
+        2 -> 314159265,
+        3 -> 314159265,
+        4 -> 314159265,
+        5 -> 314159265,
+        6 -> 314159265)
+
+      scales.foreach { scale =>
+        val df = sql(s"SELECT bround(c1, $scale) FROM t1")
+        val expected = expectedResults(scale)
+        checkAnswer(df, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("bround function with varying scales for longPi") {
+    withTable("t1") {
+      val longPi: Long = 31415926535897932L
+      sql(s"CREATE TABLE t1(c1 BIGINT) USING parquet")
+      sql(s"INSERT INTO t1 VALUES($longPi)")
+
+      val scales = -6 to 6
+      val expectedResults = Map(
+        -6 -> 31415926536000000L,
+        -5 -> 31415926535900000L,
+        -4 -> 31415926535900000L,
+        -3 -> 31415926535898000L,
+        -2 -> 31415926535897900L,
+        -1 -> 31415926535897930L,
+        0 -> 31415926535897932L,
+        1 -> 31415926535897932L,
+        2 -> 31415926535897932L,
+        3 -> 31415926535897932L,
+        4 -> 31415926535897932L,
+        5 -> 31415926535897932L,
+        6 -> 31415926535897932L)
+
+      scales.foreach { scale =>
+        val df = sql(s"SELECT bround(c1, $scale) FROM t1")
+        val expected = expectedResults(scale)
+        checkAnswer(df, Seq(Row(expected)))
+      }
+    }
+  }
+
+  test("bround function for null values") {
+    withTable("t1") {
+      sql("CREATE TABLE t1(c1 DOUBLE) USING parquet")
+      sql("INSERT INTO t1 VALUES(NULL)")
+
+      val scales = -6 to 6
+      scales.foreach { scale =>
+        val df = sql(s"SELECT bround(c1, $scale) FROM t1")
+        checkAnswer(df, Seq(Row(null)))
+      }
+    }
+  }
 }
diff --git 
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
index da335fc9..349b489a 100644
--- 
a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
+++ 
b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala
@@ -247,7 +247,7 @@ class AuronQuerySuite extends AuronQueryTest with 
BaseAuronSQLSuite with AuronSQ
                      |LOCATION '$path'
                      |""".stripMargin)
               sql("MSCK REPAIR TABLE t")
-              val expected = if (forcePositionalEvolution) {
+              if (forcePositionalEvolution) {
                 correctAnswer
               } else {
                 Seq(Row(null, 2, 1), Row(null, 4, 2), Row(null, 6, 3), 
Row(null, null, 4))
diff --git 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 79803555..13a627f2 100644
--- 
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ 
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -862,6 +862,14 @@ object NativeConverters extends Logging {
             buildExtScalarFunction("Spark_Round", Seq(e.child, Literal(0L)), 
e.dataType)
         }
 
+      case e: BRound =>
+        e.scale match {
+          case Literal(n: Int, _) =>
+            buildExtScalarFunction("Spark_BRound", Seq(e.child, 
Literal(n.toLong)), e.dataType)
+          case _ =>
+            buildExtScalarFunction("Spark_BRound", Seq(e.child, Literal(0L)), 
e.dataType)
+        }
+
       case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum, 
e.children, e.dataType)
       case e: FindInSet =>
         buildScalarFunction(pb.ScalarFunction.FindInSet, e.children, 
e.dataType)

Reply via email to