This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/auron.git
The following commit(s) were added to refs/heads/master by this push:
new 4c52bf44 [AURON#1327] Implement native function of `round` (#1426)
4c52bf44 is described below
commit 4c52bf4423ff5a2e9024a9c3bc2a80e039ff5079
Author: slfan1989 <[email protected]>
AuthorDate: Sat Oct 11 12:01:15 2025 +0800
[AURON#1327] Implement native function of `round` (#1426)
### Which issue does this PR close?
Closes #1327.
### Rationale for this change
`spark_round` is a Rust implementation of an Apache Spark-style round
function for the DataFusion query engine. Its primary purpose is to perform
rounding operations on numerical values, adhering to Spark's HALF_UP rounding
mode (i.e., `0.5` rounds to `1`, `-0.5` rounds to `-1`). It supports multiple
data types (`Float64`, `Float32`, `Int16`, `Int32`, `Int64`, `Decimal128`) and
can handle negative precision and null values.
### What changes are included in this PR?
- We implemented the Round function following Spark’s HALF_UP rounding
semantics,
ensuring full behavioral alignment with Spark SQL.
- For validation, we directly reused the unit tests from
`MathExpressionsSuite#round/bround`,
comparing our implementation against Spark’s native results using:
```
checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow)
checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow)
checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow)
checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow)
checkEvaluation(Round(floatPi, scale), floatResults(i), EmptyRow)
```
- We also added additional boundary test cases to ensure that spark_round
behaves correctly under edge conditions such as large numbers, small
numbers, and negative scales.
### Are there any user-facing changes?
No.
### How was this patch tested?
Unit Test.
---
native-engine/datafusion-ext-functions/Cargo.toml | 2 +-
native-engine/datafusion-ext-functions/src/lib.rs | 2 +
.../datafusion-ext-functions/src/spark_round.rs | 447 +++++++++++++++++++++
.../spark/sql/auron/AuronFunctionSuite.scala | 154 +++++++
.../apache/spark/sql/auron/NativeConverters.scala | 10 +-
5 files changed, 611 insertions(+), 4 deletions(-)
diff --git a/native-engine/datafusion-ext-functions/Cargo.toml
b/native-engine/datafusion-ext-functions/Cargo.toml
index b172b09c..495e4c7a 100644
--- a/native-engine/datafusion-ext-functions/Cargo.toml
+++ b/native-engine/datafusion-ext-functions/Cargo.toml
@@ -33,4 +33,4 @@ log = { workspace = true }
num = { workspace = true }
paste = { workspace = true }
serde_json = { workspace = true }
-sonic-rs = { workspace = true }
+sonic-rs = { workspace = true }
\ No newline at end of file
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs
b/native-engine/datafusion-ext-functions/src/lib.rs
index 5f311823..f2311f37 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -27,6 +27,7 @@ mod spark_make_array;
mod spark_make_decimal;
mod spark_normalize_nan_and_zero;
mod spark_null_if;
+mod spark_round;
mod spark_sha2;
mod spark_strings;
mod spark_unscaled_value;
@@ -60,6 +61,7 @@ pub fn create_spark_ext_function(name: &str) ->
Result<ScalarFunctionImplementat
"Month" => Arc::new(spark_dates::spark_month),
"Day" => Arc::new(spark_dates::spark_day),
"BrickhouseArrayUnion" =>
Arc::new(brickhouse::array_union::array_union),
+ "Round" => Arc::new(spark_round::spark_round),
"NormalizeNanAndZero" => {
Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
}
diff --git a/native-engine/datafusion-ext-functions/src/spark_round.rs
b/native-engine/datafusion-ext-functions/src/spark_round.rs
new file mode 100644
index 00000000..e8de0e57
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_round.rs
@@ -0,0 +1,447 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+ array::{Decimal128Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array},
+ datatypes::DataType,
+};
+use datafusion::{
+ common::{
+ DataFusionError, Result, ScalarValue,
+ cast::{
+ as_decimal128_array, as_float32_array, as_float64_array,
as_int16_array,
+ as_int32_array, as_int64_array,
+ },
+ },
+ physical_plan::ColumnarValue,
+};
+
+/// Spark-style `round(expr, scale)` implementation.
+/// - Uses HALF_UP rounding mode (`0.5 → 1`, `-0.5 → -1`)
+/// - Supports negative scales (e.g., `round(123.4, -1) = 120`)
+/// - Handles Float, Decimal, Int16/32/64
+/// - Null-safe
+pub fn spark_round(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ if args.len() != 2 {
+ return Err(DataFusionError::Execution(
+ "spark_round() requires two arguments".to_string(),
+ ));
+ }
+
+ let value = &args[0];
+ let scale_val = &args[1];
+
+ // Parse scale (must be a literal integer)
+ let scale = match scale_val {
+ ColumnarValue::Scalar(ScalarValue::Int32(Some(n))) => *n,
+ ColumnarValue::Scalar(ScalarValue::Int64(Some(n))) => *n as i32,
+ _ => {
+ return Err(DataFusionError::Execution(
+ "spark_round() scale must be a literal integer".to_string(),
+ ));
+ }
+ };
+
+ match value {
+ // ---------- Array input ----------
+ ColumnarValue::Array(arr) => match arr.data_type() {
+ DataType::Decimal128(..) => {
+ let dec_arr = as_decimal128_array(arr)?;
+ let precision = dec_arr.precision();
+ let in_scale = dec_arr.scale();
+
+ let result =
Decimal128Array::from_iter(dec_arr.iter().map(|opt| {
+ opt.map(|v| {
+ let diff = in_scale as i32 - scale;
+ if diff >= 0 {
+ round_i128_half_up(v, -diff)
+ } else {
+ v * 10_i128.pow((-diff) as u32)
+ }
+ })
+ }))
+ .with_precision_and_scale(precision, in_scale)
+ .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+
+ DataType::Int64 =>
Ok(ColumnarValue::Array(Arc::new(Int64Array::from_iter(
+ as_int64_array(arr)?
+ .iter()
+ .map(|opt| opt.map(|v| round_i128_half_up(v as i128,
scale) as i64)),
+ )))),
+
+ DataType::Int32 =>
Ok(ColumnarValue::Array(Arc::new(Int32Array::from_iter(
+ as_int32_array(arr)?
+ .iter()
+ .map(|opt| opt.map(|v| round_i128_half_up(v as i128,
scale) as i32)),
+ )))),
+
+ DataType::Int16 =>
Ok(ColumnarValue::Array(Arc::new(Int16Array::from_iter(
+ as_int16_array(arr)?
+ .iter()
+ .map(|opt| opt.map(|v| round_i128_half_up(v as i128,
scale) as i16)),
+ )))),
+
+ DataType::Float32 => {
+ // Handle Float32 Array case
+ let arr = as_float32_array(arr)?;
+ let factor = 10_f32.powi(scale);
+ let result = Float32Array::from_iter(arr.iter().map(|opt| {
+ opt.map(|v| {
+ if v.is_nan() || v.is_infinite() {
+ v
+ } else {
+ round_half_up_f32(v * factor) / factor
+ }
+ })
+ }));
+
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+
+ // Float64 fallback
+ _ => {
+ let arr = as_float64_array(arr)?;
+ let factor = 10_f64.powi(scale);
+ let result = Float64Array::from_iter(arr.iter().map(|opt| {
+ opt.map(|v| {
+ if v.is_nan() || v.is_infinite() {
+ v
+ } else {
+ round_half_up_f64(v * factor) / factor
+ }
+ })
+ }));
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ },
+
+ // ---------- Scalar input ----------
+ ColumnarValue::Scalar(sv) => {
+ if sv.is_null() {
+ return Ok(ColumnarValue::Scalar(sv.clone()));
+ }
+
+ Ok(match sv {
+ ScalarValue::Float64(Some(v)) => {
+ let f = 10_f64.powi(scale);
+
ColumnarValue::Scalar(ScalarValue::Float64(Some(round_half_up_f64(v * f) / f)))
+ }
+ ScalarValue::Float32(Some(v)) => {
+ let f = 10_f64.powi(scale);
+ ColumnarValue::Scalar(ScalarValue::Float32(Some(
+ (round_half_up_f64((*v as f64) * f) / f) as f32,
+ )))
+ }
+ ScalarValue::Int64(Some(v)) =>
ColumnarValue::Scalar(ScalarValue::Int64(Some(
+ round_i128_half_up(*v as i128, scale) as i64,
+ ))),
+ ScalarValue::Int32(Some(v)) =>
ColumnarValue::Scalar(ScalarValue::Int32(Some(
+ round_i128_half_up(*v as i128, scale) as i32,
+ ))),
+ ScalarValue::Int16(Some(v)) =>
ColumnarValue::Scalar(ScalarValue::Int16(Some(
+ round_i128_half_up(*v as i128, scale) as i16,
+ ))),
+ ScalarValue::Decimal128(Some(v), p, s) =>
ColumnarValue::Scalar(
+ ScalarValue::Decimal128(Some(round_i128_half_up(*v,
scale)), *p, *s),
+ ),
+ _ => {
+ return Err(DataFusionError::Execution(
+ "Unsupported type for spark_round()".to_string(),
+ ));
+ }
+ })
+ }
+ }
+}
+
/// Spark-style HALF_UP rounding to an integer (`0.5 → 1`, `-0.5 → -1`).
fn round_half_up_f64(x: f64) -> f64 {
    // Shift half a unit away from zero, then drop the fraction toward zero.
    // For x >= 0 this equals (x + 0.5).floor(); for x < 0, (x - 0.5).ceil().
    (x + 0.5_f64.copysign(x)).trunc()
}
+
/// Spark-style HALF_UP rounding to an integer (`0.5 → 1`, `-0.5 → -1`) for Float32.
fn round_half_up_f32(x: f32) -> f32 {
    // Shift half a unit away from zero, then drop the fraction toward zero.
    // For x >= 0 this equals (x + 0.5).floor(); for x < 0, (x - 0.5).ceil().
    (x + 0.5_f32.copysign(x)).trunc()
}
+
/// Integer rounding using Spark's HALF_UP logic without float precision loss.
/// A non-negative `scale` is a no-op; a negative `scale` rounds away the last
/// `-scale` decimal digits, with ties moving away from zero.
fn round_i128_half_up(value: i128, scale: i32) -> i128 {
    if scale >= 0 {
        return value;
    }
    let factor = 10_i128.pow(scale.unsigned_abs());
    let quotient = value / factor;
    let remainder = value % factor;
    // Bump the quotient one step away from zero when the dropped digits are
    // at least half of `factor` (remainder.signum() supplies the direction).
    let bump = if remainder.abs() * 2 >= factor {
        remainder.signum()
    } else {
        0
    };
    (quotient + bump) * factor
}
+
#[cfg(test)]
mod tests {
    use datafusion::{
        common::{Result, ScalarValue, cast::*},
        physical_plan::ColumnarValue,
    };

    use super::*;

    /// Decimal128 arrays keep their (precision, scale) and round at the
    /// requested digit.
    #[test]
    fn test_round_decimal() -> Result<()> {
        let input = Arc::new(
            Decimal128Array::from_iter_values([12345_i128, -67895_i128])
                .with_precision_and_scale(10, 2)?,
        );

        let rounded = spark_round(&[
            ColumnarValue::Array(input.clone()),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
        ])?;
        assert!(matches!(rounded, ColumnarValue::Array(_)));

        let arr = rounded.into_array(2)?;
        let arr = as_decimal128_array(&arr)?;
        assert_eq!(
            arr.iter().collect::<Vec<_>>(),
            vec![Some(12350_i128), Some(-67900_i128)]
        );
        Ok(())
    }

    /// A negative scale rounds to tens, hundreds, etc.
    #[test]
    fn test_round_negative_scale() -> Result<()> {
        let input = Arc::new(Float64Array::from(vec![Some(123.45), Some(-678.9)]));
        let rounded = spark_round(&[
            ColumnarValue::Array(input),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))),
        ])?;

        let arr = rounded.into_array(2)?;
        let arr = as_float64_array(&arr)?;
        assert_eq!(
            arr.iter().collect::<Vec<_>>(),
            vec![Some(120.0), Some(-680.0)]
        );
        Ok(())
    }

    /// Float64 arrays: positive scale, half-away-from-zero ties, null
    /// propagation.
    #[test]
    fn test_round_float() -> Result<()> {
        let input = Arc::new(Float64Array::from(vec![
            Some(1.2345),
            Some(-2.3456),
            Some(0.5),
            Some(-0.5),
            None,
        ]));

        let rounded = spark_round(&[
            ColumnarValue::Array(input),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
        ])?;

        let arr = rounded.into_array(5)?;
        let arr = as_float64_array(&arr)?;
        assert_eq!(
            arr.iter().collect::<Vec<_>>(),
            vec![Some(1.23), Some(-2.35), Some(0.5), Some(-0.5), None]
        );
        Ok(())
    }

    /// Scalar Float64: -1.5 rounds away from zero to -2.0.
    #[test]
    fn test_round_scalar() -> Result<()> {
        let rounded = spark_round(&[
            ColumnarValue::Scalar(ScalarValue::Float64(Some(-1.5))),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
        ])?;
        match rounded {
            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => assert_eq!(v, -2.0),
            _ => panic!("wrong result"),
        }
        Ok(())
    }

    /// Spark MathExpressionsSuite parity: Int16 across scales -6..=6.
    #[test]
    fn test_spark_round_short_pi_scales() -> Result<()> {
        let short_pi: i16 = 31415;
        let expected: [i16; 13] = [
            0, 0, 30000, 31000, 31400, 31420, 31415, 31415, 31415, 31415, 31415, 31415, 31415,
        ];

        for (scale, want) in (-6..=6).zip(expected) {
            let rounded = spark_round(&[
                ColumnarValue::Scalar(ScalarValue::Int16(Some(short_pi))),
                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
            ])?;
            let arr = rounded.into_array(1)?;
            assert_eq!(as_int16_array(&arr)?.value(0), want);
        }
        Ok(())
    }

    /// Spark MathExpressionsSuite parity: Float32 across scales -6..=6.
    #[test]
    fn test_spark_round_float_pi_scales() -> Result<()> {
        let float_pi = 3.1415_f32;
        let expected = [
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.141, 3.1415, 3.1415, 3.1415,
        ];

        for (scale, want) in (-6..=6).zip(expected) {
            let rounded = spark_round(&[
                ColumnarValue::Scalar(ScalarValue::Float32(Some(float_pi))),
                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
            ])?;
            let arr = rounded.into_array(1)?;
            let got = as_float32_array(&arr)?.value(0);
            assert!(
                (got - want).abs() < 1e-6,
                "Mismatch at scale {scale}: expected {}, got {}",
                want,
                got
            );
        }
        Ok(())
    }

    /// Spark MathExpressionsSuite parity: Float64 across scales -6..=6.
    #[test]
    fn test_spark_round_double_pi_scales() -> Result<()> {
        let double_pi = std::f64::consts::PI;
        let expected = [
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142, 3.1416, 3.14159, 3.141593,
        ];

        for (scale, want) in (-6..=6).zip(expected) {
            let rounded = spark_round(&[
                ColumnarValue::Scalar(ScalarValue::Float64(Some(double_pi))),
                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
            ])?;
            let arr = rounded.into_array(1)?;
            let got = as_float64_array(&arr)?.value(0);
            assert!(
                (got - want).abs() < 1e-9,
                "Mismatch at scale {scale}: expected {}, got {}",
                want,
                got
            );
        }
        Ok(())
    }

    /// Spark MathExpressionsSuite parity: Int32 across scales -6..=6.
    #[test]
    fn test_spark_round_int_pi_scales() -> Result<()> {
        let int_pi = 314159265_i32;
        let expected = [
            314000000, 314200000, 314160000, 314159000, 314159300, 314159270, 314159265,
            314159265, 314159265, 314159265, 314159265, 314159265, 314159265,
        ];

        for (scale, want) in (-6..=6).zip(expected) {
            let rounded = spark_round(&[
                ColumnarValue::Scalar(ScalarValue::Int32(Some(int_pi))),
                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
            ])?;
            let arr = rounded.into_array(1)?;
            let got = as_int32_array(&arr)?.value(0);
            assert_eq!(
                got, want,
                "Mismatch at scale {scale}: expected {}, got {}",
                want, got
            );
        }
        Ok(())
    }

    /// Spark MathExpressionsSuite parity: Decimal128 (Long in Spark) across
    /// scales -6..=6.
    #[test]
    fn test_spark_round_long_pi_scales() -> Result<()> {
        let long_pi = 31415926535897932_i128;
        let expected = [
            31415926536000000,
            31415926535900000,
            31415926535900000,
            31415926535898000,
            31415926535897900,
            31415926535897930,
            31415926535897932,
            31415926535897932,
            31415926535897932,
            31415926535897932,
            31415926535897932,
            31415926535897932,
            31415926535897932,
        ];

        for (scale, want) in (-6..=6).zip(expected) {
            let rounded = spark_round(&[
                ColumnarValue::Scalar(ScalarValue::Decimal128(Some(long_pi), 38, 0)),
                ColumnarValue::Scalar(ScalarValue::Int32(Some(scale))),
            ])?;
            let arr = rounded.into_array(1)?;
            let got = as_decimal128_array(&arr)?.value(0);
            assert_eq!(
                got, want,
                "Mismatch at scale {scale}: expected {}, got {}",
                want, got
            );
        }
        Ok(())
    }
}
diff --git
a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
index 2f7e5707..fb46cd08 100644
---
a/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
+++
b/spark-extension-shims-spark3/src/test/scala/org/apache/spark/sql/auron/AuronFunctionSuite.scala
@@ -120,4 +120,158 @@ class AuronFunctionSuite
checkAnswer(df, Seq(Row("uron Spark SQ")))
}
}
+
+ test("round function with varying scales for intPi") {
+ withTable("t2") {
+ sql("CREATE TABLE t2 (c1 INT) USING parquet")
+
+ val intPi: Int = 314159265
+ sql(s"INSERT INTO t2 VALUES($intPi)")
+
+ val scales = -6 to 6
+ val expectedResults = Map(
+ -6 -> 314000000,
+ -5 -> 314200000,
+ -4 -> 314160000,
+ -3 -> 314159000,
+ -2 -> 314159300,
+ -1 -> 314159270,
+ 0 -> 314159265,
+ 1 -> 314159265,
+ 2 -> 314159265,
+ 3 -> 314159265,
+ 4 -> 314159265,
+ 5 -> 314159265,
+ 6 -> 314159265)
+
+ scales.foreach { scale =>
+ val df = sql(s"SELECT round(c1, $scale) AS xx FROM t2")
+ val expected = expectedResults(scale)
+ checkAnswer(df, Seq(Row(expected)))
+ }
+ }
+ }
+
+ test("round function with varying scales for doublePi") {
+ withTable("t1") {
+ sql("create table t1(c1 double) using parquet")
+
+ val doublePi: Double = math.Pi
+ sql(s"insert into t1 values($doublePi)")
+ val scales = -6 to 6
+ val expectedResults = Map(
+ -6 -> 0.0,
+ -5 -> 0.0,
+ -4 -> 0.0,
+ -3 -> 0.0,
+ -2 -> 0.0,
+ -1 -> 0.0,
+ 0 -> 3.0,
+ 1 -> 3.1,
+ 2 -> 3.14,
+ 3 -> 3.142,
+ 4 -> 3.1416,
+ 5 -> 3.14159,
+ 6 -> 3.141593)
+
+ scales.foreach { scale =>
+ val df = sql(s"select round(c1, $scale) from t1")
+ val expected = expectedResults(scale)
+ checkAnswer(df, Seq(Row(expected)))
+ }
+ }
+ }
+
+ test("round function with varying scales for floatPi") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c1 FLOAT) USING parquet")
+
+ val floatPi: Float = 3.1415f
+ sql(s"INSERT INTO t1 VALUES($floatPi)")
+
+ val scales = -6 to 6
+ val expectedResults = Map(
+ -6 -> 0.0f,
+ -5 -> 0.0f,
+ -4 -> 0.0f,
+ -3 -> 0.0f,
+ -2 -> 0.0f,
+ -1 -> 0.0f,
+ 0 -> 3.0f,
+ 1 -> 3.1f,
+ 2 -> 3.14f,
+ 3 -> 3.142f,
+ 4 -> 3.1415f,
+ 5 -> 3.1415f,
+ 6 -> 3.1415f)
+
+ scales.foreach { scale =>
+ val df = sql(s"select round(c1, $scale) from t1")
+ val expected = expectedResults(scale)
+ checkAnswer(df, Seq(Row(expected)))
+ }
+ }
+ }
+
+ test("round function with varying scales for shortPi") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c1 SMALLINT) USING parquet")
+
+ val shortPi: Short = 31415
+ sql(s"INSERT INTO t1 VALUES($shortPi)")
+
+ val scales = -6 to 6
+ val expectedResults = Map(
+ -6 -> 0.toShort,
+ -5 -> 0.toShort,
+ -4 -> 30000.toShort,
+ -3 -> 31000.toShort,
+ -2 -> 31400.toShort,
+ -1 -> 31420.toShort,
+ 0 -> 31415.toShort,
+ 1 -> 31415.toShort,
+ 2 -> 31415.toShort,
+ 3 -> 31415.toShort,
+ 4 -> 31415.toShort,
+ 5 -> 31415.toShort,
+ 6 -> 31415.toShort)
+
+ scales.foreach { scale =>
+ val df = sql(s"SELECT round(c1, $scale) FROM t1")
+ val expected = expectedResults(scale)
+ checkAnswer(df, Seq(Row(expected)))
+ }
+ }
+ }
+
+ test("round function with varying scales for longPi") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c1 BIGINT) USING parquet")
+
+ val longPi: Long = 31415926535897932L
+ sql(s"INSERT INTO t1 VALUES($longPi)")
+
+ val scales = -6 to 6
+ val expectedResults = Map(
+ -6 -> 31415926536000000L,
+ -5 -> 31415926535900000L,
+ -4 -> 31415926535900000L,
+ -3 -> 31415926535898000L,
+ -2 -> 31415926535897900L,
+ -1 -> 31415926535897930L,
+ 0 -> 31415926535897932L,
+ 1 -> 31415926535897932L,
+ 2 -> 31415926535897932L,
+ 3 -> 31415926535897932L,
+ 4 -> 31415926535897932L,
+ 5 -> 31415926535897932L,
+ 6 -> 31415926535897932L)
+
+ scales.foreach { scale =>
+ val df = sql(s"SELECT round(c1, $scale) FROM t1")
+ val expected = expectedResults(scale)
+ checkAnswer(df, Seq(Row(expected)))
+ }
+ }
+ }
}
diff --git
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 3ec49807..72f96bf0 100644
---
a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++
b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -815,9 +815,13 @@ object NativeConverters extends Logging {
buildScalarFunction(pb.ScalarFunction.Factorial, e.children,
e.dataType)
case e: Hex => buildScalarFunction(pb.ScalarFunction.Hex, e.children,
e.dataType)
- // TODO: datafusion's round() has different behavior from spark
- // case e @ Round(_1, Literal(n: Int, _)) if
_1.dataType.isInstanceOf[FractionalType] =>
- // buildScalarFunction(pb.ScalarFunction.Round, Seq(_1,
Literal(n.toLong)), e.dataType)
+ case e: Round =>
+ e.scale match {
+ case Literal(n: Int, _) =>
+ buildExtScalarFunction("Round", Seq(e.child, Literal(n.toLong)),
e.dataType)
+ case _ =>
+ buildExtScalarFunction("Round", Seq(e.child, Literal(0L)),
e.dataType)
+ }
case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum,
e.children, e.dataType)
case e: Abs if e.dataType.isInstanceOf[FloatType] ||
e.dataType.isInstanceOf[DoubleType] =>