This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new edd63efb feat: Implement ANSI support for UnaryMinus (#471)
edd63efb is described below

commit edd63efb6965537dfc564d2b0eda5aa6c23398e7
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Tue Jun 4 02:37:51 2024 +0530

    feat: Implement ANSI support for UnaryMinus (#471)
    
    * checking for invalid inputs for unary minus
    
    * adding eval mode to expressions and proto message
    
    * extending evaluate function for negative expression
    
    * remove print statements
    
    * fix format errors
    
    * removing units
    
    * fix clippy errors
    
    * expect instead of unwrap, map_err instead of match and removing Float16
    
    * adding test case for unary negative integer overflow
    
    * added a function to make the code more readable
    
    * adding comet sql ansi config
    
    * using withTempDir and checkSparkAnswerAndOperator
    
    * adding macros to improve code readability
    
    * using withParquetTable
    
    * adding scalar tests
    
    * adding more test cases and bug fix
    
    * using failonerror and removing eval_mode
    
    * bug fix
    
    * removing checks for float64 and monthdaynano
    
    * removing checks of float and monthday nano
    
    * adding checks while evaluating bounds
    
    * IntervalDayTime splitting i64 and then checking
    
    * Adding interval test
    
    * fix ci errors
---
 core/src/errors.rs                                 |   3 +
 core/src/execution/datafusion/expressions/mod.rs   |   1 +
 .../execution/datafusion/expressions/negative.rs   | 270 +++++++++++++++++++++
 core/src/execution/datafusion/planner.rs           |   9 +-
 core/src/execution/proto/expr.proto                |   1 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   3 +-
 .../org/apache/comet/CometExpressionSuite.scala    |  98 ++++++++
 7 files changed, 381 insertions(+), 4 deletions(-)

diff --git a/core/src/errors.rs b/core/src/errors.rs
index 04a1629d..af4fd269 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -88,6 +88,9 @@ pub enum CometError {
         to_type: String,
     },
 
+    #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set 
\"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    ArithmeticOverflow { from_type: String },
+
     #[error(transparent)]
     Arrow {
         #[from]
diff --git a/core/src/execution/datafusion/expressions/mod.rs 
b/core/src/execution/datafusion/expressions/mod.rs
index 9db4b65b..084fef2d 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -29,6 +29,7 @@ pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
 pub mod correlation;
 pub mod covariance;
+pub mod negative;
 pub mod stats;
 pub mod stddev;
 pub mod strings;
diff --git a/core/src/execution/datafusion/expressions/negative.rs 
b/core/src/execution/datafusion/expressions/negative.rs
new file mode 100644
index 00000000..e7aa2ac6
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/negative.rs
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::errors::CometError;
+use arrow::{compute::kernels::numeric::neg_wrapping, 
datatypes::IntervalDayTimeType};
+use arrow_array::RecordBatch;
+use arrow_schema::{DataType, Schema};
+use datafusion::{
+    logical_expr::{interval_arithmetic::Interval, ColumnarValue},
+    physical_expr::PhysicalExpr,
+};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_physical_expr::{
+    aggregate::utils::down_cast_any_ref, sort_properties::SortProperties,
+};
+use std::{
+    any::Any,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+/// Builds a [`NegativeExpr`] that negates `expr`.
+///
+/// `fail_on_error` turns on the ANSI-mode overflow checks performed during
+/// evaluation. Construction itself always succeeds and returns `Ok`.
+pub fn create_negate_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    fail_on_error: bool,
+) -> Result<Arc<dyn PhysicalExpr>, CometError> {
+    let negated = NegativeExpr::new(expr, fail_on_error);
+    Ok(Arc::new(negated))
+}
+
+/// Negative expression (unary minus).
+///
+/// Negates its child with wrapping arithmetic. When `fail_on_error` is set,
+/// evaluation first looks for integer/interval inputs whose negation would
+/// overflow and reports them as an `ARITHMETIC_OVERFLOW` error instead.
+#[derive(Debug, Hash)]
+pub struct NegativeExpr {
+    /// Input expression
+    arg: Arc<dyn PhysicalExpr>,
+    /// When true (ANSI mode), overflowing negations raise an error instead of
+    /// wrapping; the Spark side sets this from `UnaryMinus`'s `failOnError`.
+    fail_on_error: bool,
+}
+
+fn arithmetic_overflow_error(from_type: &str) -> CometError {
+    CometError::ArithmeticOverflow {
+        from_type: from_type.to_string(),
+    }
+}
+
+macro_rules! check_overflow {
+    ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
+        let typed_array = $array
+            .as_any()
+            .downcast_ref::<$array_type>()
+            .expect(concat!(stringify!($array_type), " expected"));
+        for i in 0..typed_array.len() {
+            if typed_array.value(i) == $min_val {
+                if $type_name == "byte" || $type_name == "short" {
+                    let value = typed_array.value(i).to_string() + " caused";
+                    return 
Err(arithmetic_overflow_error(value.as_str()).into());
+                }
+                return Err(arithmetic_overflow_error($type_name).into());
+            }
+        }
+    }};
+}
+
+impl NegativeExpr {
+    /// Create a new negative expression over `arg`.
+    ///
+    /// `fail_on_error` enables the ANSI-mode overflow checks applied in `evaluate`.
+    pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
+        Self { arg, fail_on_error }
+    }
+
+    /// Get the input expression
+    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
+        &self.arg
+    }
+}
+
+impl std::fmt::Display for NegativeExpr {
+    /// Renders the expression as `(- <child>)`, using the child's own `Display` output.
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.write_fmt(format_args!("(- {})", self.arg))
+    }
+}
+
+impl PhysicalExpr for NegativeExpr {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        self.arg.data_type(input_schema)
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.arg.nullable(input_schema)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let arg = self.arg.evaluate(batch)?;
+
+        // overflow checks only apply in ANSI mode
+        // datatypes supported are byte, short, integer, long, float, interval
+        match arg {
+            ColumnarValue::Array(array) => {
+                if self.fail_on_error {
+                    match array.data_type() {
+                        DataType::Int8 => {
+                            check_overflow!(array, arrow::array::Int8Array, 
i8::MIN, "byte")
+                        }
+                        DataType::Int16 => {
+                            check_overflow!(array, arrow::array::Int16Array, 
i16::MIN, "short")
+                        }
+                        DataType::Int32 => {
+                            check_overflow!(array, arrow::array::Int32Array, 
i32::MIN, "integer")
+                        }
+                        DataType::Int64 => {
+                            check_overflow!(array, arrow::array::Int64Array, 
i64::MIN, "long")
+                        }
+                        DataType::Interval(value) => match value {
+                            arrow::datatypes::IntervalUnit::YearMonth => 
check_overflow!(
+                                array,
+                                arrow::array::IntervalYearMonthArray,
+                                i32::MIN,
+                                "interval"
+                            ),
+                            arrow::datatypes::IntervalUnit::DayTime => 
check_overflow!(
+                                array,
+                                arrow::array::IntervalDayTimeArray,
+                                i64::MIN,
+                                "interval"
+                            ),
+                            arrow::datatypes::IntervalUnit::MonthDayNano => {
+                                // Overflow checks are not supported
+                            }
+                        },
+                        _ => {
+                            // Overflow checks are not supported for other 
datatypes
+                        }
+                    }
+                }
+                let result = neg_wrapping(array.as_ref())?;
+                Ok(ColumnarValue::Array(result))
+            }
+            ColumnarValue::Scalar(scalar) => {
+                if self.fail_on_error {
+                    match scalar {
+                        ScalarValue::Int8(value) => {
+                            if value == Some(i8::MIN) {
+                                return Err(arithmetic_overflow_error(" 
caused").into());
+                            }
+                        }
+                        ScalarValue::Int16(value) => {
+                            if value == Some(i16::MIN) {
+                                return Err(arithmetic_overflow_error(" 
caused").into());
+                            }
+                        }
+                        ScalarValue::Int32(value) => {
+                            if value == Some(i32::MIN) {
+                                return 
Err(arithmetic_overflow_error("integer").into());
+                            }
+                        }
+                        ScalarValue::Int64(value) => {
+                            if value == Some(i64::MIN) {
+                                return 
Err(arithmetic_overflow_error("long").into());
+                            }
+                        }
+                        ScalarValue::IntervalDayTime(value) => {
+                            let (days, ms) =
+                                
IntervalDayTimeType::to_parts(value.unwrap_or_default());
+                            if days == i32::MIN || ms == i32::MIN {
+                                return 
Err(arithmetic_overflow_error("interval").into());
+                            }
+                        }
+                        ScalarValue::IntervalYearMonth(value) => {
+                            if value == Some(i32::MIN) {
+                                return 
Err(arithmetic_overflow_error("interval").into());
+                            }
+                        }
+                        _ => {
+                            // Overflow checks are not supported for other 
datatypes
+                        }
+                    }
+                }
+                Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
+            }
+        }
+    }
+
+    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.arg.clone()]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(NegativeExpr::new(
+            children[0].clone(),
+            self.fail_on_error,
+        )))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.hash(&mut s);
+    }
+
+    /// Given the child interval of a NegativeExpr, it calculates the 
NegativeExpr's interval.
+    /// It replaces the upper and lower bounds after multiplying them with -1.
+    /// Ex: `(a, b]` => `[-b, -a)`
+    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
+        Interval::try_new(
+            children[0].upper().arithmetic_negate()?,
+            children[0].lower().arithmetic_negate()?,
+        )
+    }
+
+    /// Returns a new [`Interval`] of a NegativeExpr  that has the existing 
`interval` given that
+    /// given the input interval is known to be `children`.
+    fn propagate_constraints(
+        &self,
+        interval: &Interval,
+        children: &[&Interval],
+    ) -> Result<Option<Vec<Interval>>> {
+        let child_interval = children[0];
+
+        if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
+            || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
+            || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
+            || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
+        {
+            return Err(CometError::ArithmeticOverflow {
+                from_type: "long".to_string(),
+            }
+            .into());
+        }
+
+        let negated_interval = Interval::try_new(
+            interval.upper().arithmetic_negate()?,
+            interval.lower().arithmetic_negate()?,
+        )?;
+
+        Ok(child_interval
+            .intersect(negated_interval)?
+            .map(|result| vec![result]))
+    }
+
+    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
+    fn get_ordering(&self, children: &[SortProperties]) -> SortProperties {
+        -children[0]
+    }
+}
+
+impl PartialEq<dyn Any> for NegativeExpr {
+    /// Two `NegativeExpr`s are equal when they wrap equal children and share the same
+    /// `fail_on_error` flag. The flag must participate for two reasons: it changes the
+    /// evaluation result (error vs. wrapped value), and the derived `Hash` hashes it, so
+    /// ignoring it would let equal values hash differently.
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self.arg.eq(&x.arg) && self.fail_on_error == x.fail_on_error)
+            .unwrap_or(false)
+    }
+}
diff --git a/core/src/execution/datafusion/planner.rs 
b/core/src/execution/datafusion/planner.rs
index 20119b13..3a8548f7 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -33,7 +33,7 @@ use datafusion::{
         expressions::{
             in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, 
Column, Count,
             FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
-            Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, 
Sum, UnKnownColumn,
+            Literal as DataFusionLiteral, Max, Min, NotExpr, Sum, 
UnKnownColumn,
         },
         AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
     },
@@ -70,6 +70,7 @@ use crate::{
                 correlation::Correlation,
                 covariance::Covariance,
                 if_expr::IfExpr,
+                negative,
                 scalar_funcs::create_comet_physical_fun,
                 stats::StatsType,
                 stddev::Stddev,
@@ -563,8 +564,10 @@ impl PhysicalPlanner {
                 Ok(Arc::new(NotExpr::new(child)))
             }
             ExprStruct::Negative(expr) => {
-                let child = self.create_expr(expr.child.as_ref().unwrap(), 
input_schema)?;
-                Ok(Arc::new(NegativeExpr::new(child)))
+                let child: Arc<dyn PhysicalExpr> =
+                    self.create_expr(expr.child.as_ref().unwrap(), 
input_schema.clone())?;
+                let result = negative::create_negate_expr(child, 
expr.fail_on_error);
+                result.map_err(|e| ExecutionError::GeneralError(e.to_string()))
             }
             ExprStruct::NormalizeNanAndZero(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), 
input_schema)?;
diff --git a/core/src/execution/proto/expr.proto 
b/core/src/execution/proto/expr.proto
index bcd98387..9c604901 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -454,6 +454,7 @@ message Not {
 
 message Negative {
   Expr child = 1;
+  // When true (Spark ANSI mode), the native NegativeExpr raises an
+  // ARITHMETIC_OVERFLOW error for inputs whose negation overflows instead of
+  // wrapping. Populated from UnaryMinus.failOnError.
+  bool fail_on_error = 2;
 }
 
 message IfExpr {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 79a75102..5fe290cf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1984,11 +1984,12 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
             None
           }
 
-        case UnaryMinus(child, _) =>
+        case UnaryMinus(child, failOnError) =>
           val childExpr = exprToProtoInternal(child, inputs)
           if (childExpr.isDefined) {
             val builder = ExprOuterClass.Negative.newBuilder()
             builder.setChild(childExpr.get)
+            builder.setFailOnError(failOnError)
             Some(
               ExprOuterClass.Expr
                 .newBuilder()
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e15b09ca..e69054db 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1548,5 +1548,103 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       }
     }
   }
+  // Unary minus on the minimum value of each integral type (plus interval and float
+  // inputs). With ANSI mode off, Comet's answer must match Spark's. With ANSI mode on,
+  // Spark and Comet must either both succeed or both fail with an "... overflow" message.
+  test("unary negative integer overflow test") {
+    // Toggles Spark's ANSI mode together with Comet's ANSI flag, with Comet exec enabled.
+    def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = {
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> enabled.toString,
+        CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString,
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true")(f)
+    }
+
+    // Runs `query` on both engines. If both throw, each message must contain
+    // "<dtype> overflow"; if only one throws, the test fails.
+    def checkOverflow(query: String, dtype: String): Unit = {
+      checkSparkMaybeThrows(sql(query)) match {
+        case (Some(sparkException), Some(cometException)) =>
+          assert(sparkException.getMessage.contains(dtype + " overflow"))
+          assert(cometException.getMessage.contains(dtype + " overflow"))
+        case (None, None) => assert(true) // got same outputs
+        case (None, Some(ex)) =>
+          fail("Comet threw an exception but Spark did not " + ex.getMessage)
+        case (Some(_), None) =>
+          fail("Spark threw an exception but Comet did not")
+      }
+    }
+
+    // Exposes the Parquet file at `path` as table `t`. Checks the non-ANSI result against
+    // Spark (presumably also asserting that Comet operators ran), then the ANSI behaviour.
+    def runArrayTest(query: String, dtype: String, path: String): Unit = {
+      withParquetTable(path, "t") {
+        withAnsiMode(enabled = false) {
+          checkSparkAnswerAndOperator(sql(query))
+        }
+        withAnsiMode(enabled = true) {
+          checkOverflow(query, dtype)
+        }
+      }
+    }
 
+    withTempDir { dir =>
+      // Array values test
+      val arrayPath = new Path(dir.toURI.toString, 
"array_test.parquet").toString
+      Seq(Int.MaxValue, 
Int.MinValue).toDF("a").write.mode("overwrite").parquet(arrayPath)
+      val arrayQuery = "select a, -a from t"
+      runArrayTest(arrayQuery, "integer", arrayPath)
+
+      // long values test
+      val longArrayPath = new Path(dir.toURI.toString, 
"long_array_test.parquet").toString
+      Seq(Long.MaxValue, Long.MinValue)
+        .toDF("a")
+        .write
+        .mode("overwrite")
+        .parquet(longArrayPath)
+      val longArrayQuery = "select a, -a from t"
+      runArrayTest(longArrayQuery, "long", longArrayPath)
+
+      // short values test
+      // Byte/short overflow messages name the offending value ("<value> caused
+      // overflow"), so these cases match on " caused" rather than a type name.
+      val shortArrayPath = new Path(dir.toURI.toString, 
"short_array_test.parquet").toString
+      Seq(Short.MaxValue, Short.MinValue)
+        .toDF("a")
+        .write
+        .mode("overwrite")
+        .parquet(shortArrayPath)
+      val shortArrayQuery = "select a, -a from t"
+      runArrayTest(shortArrayQuery, " caused", shortArrayPath)
+
+      // byte values test
+      val byteArrayPath = new Path(dir.toURI.toString, 
"byte_array_test.parquet").toString
+      Seq(Byte.MaxValue, Byte.MinValue)
+        .toDF("a")
+        .write
+        .mode("overwrite")
+        .parquet(byteArrayPath)
+      val byteArrayQuery = "select a, -a from t"
+      runArrayTest(byteArrayQuery, " caused", byteArrayPath)
+
+      // interval values test
+      // NOTE(review): relies on CAST of the string 'INTERVAL 10000000000 YEAR' producing an
+      // interval value. If it yields null or errors, the overflow path is not exercised.
+      withTable("t_interval") {
+        spark.sql("CREATE TABLE t_interval(a STRING) USING PARQUET")
+        spark.sql("INSERT INTO t_interval VALUES ('INTERVAL 10000000000 
YEAR')")
+        withAnsiMode(enabled = true) {
+          spark
+            .sql("SELECT CAST(a AS INTERVAL) AS a FROM t_interval")
+            .createOrReplaceTempView("t_interval_casted")
+          checkOverflow("SELECT a, -a FROM t_interval_casted", "interval")
+        }
+      }
+
+      // Integer MIN again, this time in a table created and populated through SQL.
+      withTable("t") {
+        sql("create table t(a int) using parquet")
+        sql("insert into t values (-2147483648)")
+        withAnsiMode(enabled = true) {
+          checkOverflow("select a, -a from t", "integer")
+        }
+      }
+
+      // Comet applies no overflow check to floats. This case presumably passes through
+      // the (None, None) branch of checkOverflow when neither engine throws.
+      withTable("t_float") {
+        sql("create table t_float(a float) using parquet")
+        sql("insert into t_float values (3.4128235E38)")
+        withAnsiMode(enabled = true) {
+          checkOverflow("select a, -a from t_float", "float")
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to