This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new edd63efb feat: Implement ANSI support for UnaryMinus (#471)
edd63efb is described below
commit edd63efb6965537dfc564d2b0eda5aa6c23398e7
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Tue Jun 4 02:37:51 2024 +0530
feat: Implement ANSI support for UnaryMinus (#471)
* checking for invalid inputs for unary minus
* adding eval mode to expressions and proto message
* extending evaluate function for negative expression
* remove print statements
* fix format errors
* removing units
* fix clippy errors
* expect instead of unwrap, map_err instead of match and removing Float16
* adding test case for unary negative integer overflow
* added a function to make the code more readable
* adding comet sql ansi config
* using withTempDir and checkSparkAnswerAndOperator
* adding macros to improve code readability
* using withParquetTable
* adding scalar tests
* adding more test cases and bug fix
* using failonerror and removing eval_mode
* bug fix
* removing checks for float64 and monthdaynano
* removing checks of float and monthday nano
* adding checks while evaluating bounds
* IntervalDayTime splitting i64 and then checking
* Adding interval test
* fix ci errors
---
core/src/errors.rs | 3 +
core/src/execution/datafusion/expressions/mod.rs | 1 +
.../execution/datafusion/expressions/negative.rs | 270 +++++++++++++++++++++
core/src/execution/datafusion/planner.rs | 9 +-
core/src/execution/proto/expr.proto | 1 +
.../org/apache/comet/serde/QueryPlanSerde.scala | 3 +-
.../org/apache/comet/CometExpressionSuite.scala | 98 ++++++++
7 files changed, 381 insertions(+), 4 deletions(-)
diff --git a/core/src/errors.rs b/core/src/errors.rs
index 04a1629d..af4fd269 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -88,6 +88,9 @@ pub enum CometError {
to_type: String,
},
+ #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set
\"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ ArithmeticOverflow { from_type: String },
+
#[error(transparent)]
Arrow {
#[from]
diff --git a/core/src/execution/datafusion/expressions/mod.rs
b/core/src/execution/datafusion/expressions/mod.rs
index 9db4b65b..084fef2d 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -29,6 +29,7 @@ pub mod avg_decimal;
pub mod bloom_filter_might_contain;
pub mod correlation;
pub mod covariance;
+pub mod negative;
pub mod stats;
pub mod stddev;
pub mod strings;
diff --git a/core/src/execution/datafusion/expressions/negative.rs
b/core/src/execution/datafusion/expressions/negative.rs
new file mode 100644
index 00000000..e7aa2ac6
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/negative.rs
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::errors::CometError;
+use arrow::{compute::kernels::numeric::neg_wrapping,
datatypes::IntervalDayTimeType};
+use arrow_array::RecordBatch;
+use arrow_schema::{DataType, Schema};
+use datafusion::{
+ logical_expr::{interval_arithmetic::Interval, ColumnarValue},
+ physical_expr::PhysicalExpr,
+};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_physical_expr::{
+ aggregate::utils::down_cast_any_ref, sort_properties::SortProperties,
+};
+use std::{
+ any::Any,
+ hash::{Hash, Hasher},
+ sync::Arc,
+};
+
+/// Builds a [`NegativeExpr`] that negates `expr`.
+///
+/// `fail_on_error` is forwarded unchanged to the expression and selects whether overflow
+/// is reported as an error (ANSI mode) or wraps. This function always returns `Ok`.
+pub fn create_negate_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    fail_on_error: bool,
+) -> Result<Arc<dyn PhysicalExpr>, CometError> {
+    let negation: Arc<dyn PhysicalExpr> = Arc::new(NegativeExpr::new(expr, fail_on_error));
+    Ok(negation)
+}
+
+/// Physical expression computing the arithmetic negation (`-x`) of its child.
+#[derive(Hash, Debug)]
+pub struct NegativeExpr {
+    /// The expression whose value is negated.
+    arg: Arc<dyn PhysicalExpr>,
+    /// When `true` (ANSI mode), negating a value that has no representable negation
+    /// produces an overflow error instead of a wrapped result.
+    fail_on_error: bool,
+}
+
+/// Wraps `from_type` (a type name such as "integer", or a "<value> caused" fragment)
+/// in the `ArithmeticOverflow` Comet error.
+fn arithmetic_overflow_error(from_type: &str) -> CometError {
+    let from_type = String::from(from_type);
+    CometError::ArithmeticOverflow { from_type }
+}
+
+macro_rules! check_overflow {
+ ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
+ let typed_array = $array
+ .as_any()
+ .downcast_ref::<$array_type>()
+ .expect(concat!(stringify!($array_type), " expected"));
+ for i in 0..typed_array.len() {
+ if typed_array.value(i) == $min_val {
+ if $type_name == "byte" || $type_name == "short" {
+ let value = typed_array.value(i).to_string() + " caused";
+ return
Err(arithmetic_overflow_error(value.as_str()).into());
+ }
+ return Err(arithmetic_overflow_error($type_name).into());
+ }
+ }
+ }};
+}
+
+impl NegativeExpr {
+    /// Create a new negative expression over `arg`.
+    ///
+    /// When `fail_on_error` is `true`, evaluation reports an arithmetic overflow for inputs
+    /// whose negation is not representable instead of wrapping.
+    pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
+        Self { arg, fail_on_error }
+    }
+
+    /// Get the input expression
+    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
+        &self.arg
+    }
+}
+
+impl std::fmt::Display for NegativeExpr {
+    /// Renders the expression as `(- <child>)`.
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        let child = &self.arg;
+        write!(f, "(- {child})")
+    }
+}
+
+impl PhysicalExpr for NegativeExpr {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+ self.arg.data_type(input_schema)
+ }
+
+ fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+ self.arg.nullable(input_schema)
+ }
+
+ fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+ let arg = self.arg.evaluate(batch)?;
+
+ // overflow checks only apply in ANSI mode
+ // datatypes supported are byte, short, integer, long, float, interval
+ match arg {
+ ColumnarValue::Array(array) => {
+ if self.fail_on_error {
+ match array.data_type() {
+ DataType::Int8 => {
+ check_overflow!(array, arrow::array::Int8Array,
i8::MIN, "byte")
+ }
+ DataType::Int16 => {
+ check_overflow!(array, arrow::array::Int16Array,
i16::MIN, "short")
+ }
+ DataType::Int32 => {
+ check_overflow!(array, arrow::array::Int32Array,
i32::MIN, "integer")
+ }
+ DataType::Int64 => {
+ check_overflow!(array, arrow::array::Int64Array,
i64::MIN, "long")
+ }
+ DataType::Interval(value) => match value {
+ arrow::datatypes::IntervalUnit::YearMonth =>
check_overflow!(
+ array,
+ arrow::array::IntervalYearMonthArray,
+ i32::MIN,
+ "interval"
+ ),
+ arrow::datatypes::IntervalUnit::DayTime =>
check_overflow!(
+ array,
+ arrow::array::IntervalDayTimeArray,
+ i64::MIN,
+ "interval"
+ ),
+ arrow::datatypes::IntervalUnit::MonthDayNano => {
+ // Overflow checks are not supported
+ }
+ },
+ _ => {
+ // Overflow checks are not supported for other
datatypes
+ }
+ }
+ }
+ let result = neg_wrapping(array.as_ref())?;
+ Ok(ColumnarValue::Array(result))
+ }
+ ColumnarValue::Scalar(scalar) => {
+ if self.fail_on_error {
+ match scalar {
+ ScalarValue::Int8(value) => {
+ if value == Some(i8::MIN) {
+ return Err(arithmetic_overflow_error("
caused").into());
+ }
+ }
+ ScalarValue::Int16(value) => {
+ if value == Some(i16::MIN) {
+ return Err(arithmetic_overflow_error("
caused").into());
+ }
+ }
+ ScalarValue::Int32(value) => {
+ if value == Some(i32::MIN) {
+ return
Err(arithmetic_overflow_error("integer").into());
+ }
+ }
+ ScalarValue::Int64(value) => {
+ if value == Some(i64::MIN) {
+ return
Err(arithmetic_overflow_error("long").into());
+ }
+ }
+ ScalarValue::IntervalDayTime(value) => {
+ let (days, ms) =
+
IntervalDayTimeType::to_parts(value.unwrap_or_default());
+ if days == i32::MIN || ms == i32::MIN {
+ return
Err(arithmetic_overflow_error("interval").into());
+ }
+ }
+ ScalarValue::IntervalYearMonth(value) => {
+ if value == Some(i32::MIN) {
+ return
Err(arithmetic_overflow_error("interval").into());
+ }
+ }
+ _ => {
+ // Overflow checks are not supported for other
datatypes
+ }
+ }
+ }
+ Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?))
+ }
+ }
+ }
+
+ fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.arg.clone()]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn PhysicalExpr>>,
+ ) -> Result<Arc<dyn PhysicalExpr>> {
+ Ok(Arc::new(NegativeExpr::new(
+ children[0].clone(),
+ self.fail_on_error,
+ )))
+ }
+
+ fn dyn_hash(&self, state: &mut dyn Hasher) {
+ let mut s = state;
+ self.hash(&mut s);
+ }
+
+ /// Given the child interval of a NegativeExpr, it calculates the
NegativeExpr's interval.
+ /// It replaces the upper and lower bounds after multiplying them with -1.
+ /// Ex: `(a, b]` => `[-b, -a)`
+ fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
+ Interval::try_new(
+ children[0].upper().arithmetic_negate()?,
+ children[0].lower().arithmetic_negate()?,
+ )
+ }
+
+ /// Returns a new [`Interval`] of a NegativeExpr that has the existing
`interval` given that
+ /// given the input interval is known to be `children`.
+ fn propagate_constraints(
+ &self,
+ interval: &Interval,
+ children: &[&Interval],
+ ) -> Result<Option<Vec<Interval>>> {
+ let child_interval = children[0];
+
+ if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
+ || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
+ || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
+ || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
+ {
+ return Err(CometError::ArithmeticOverflow {
+ from_type: "long".to_string(),
+ }
+ .into());
+ }
+
+ let negated_interval = Interval::try_new(
+ interval.upper().arithmetic_negate()?,
+ interval.lower().arithmetic_negate()?,
+ )?;
+
+ Ok(child_interval
+ .intersect(negated_interval)?
+ .map(|result| vec![result]))
+ }
+
+ /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
+ fn get_ordering(&self, children: &[SortProperties]) -> SortProperties {
+ -children[0]
+ }
+}
+
+impl PartialEq<dyn Any> for NegativeExpr {
+    /// Two negations are equal when they wrap equal children and use the same overflow mode.
+    ///
+    /// `fail_on_error` must participate: it changes evaluation results, and the derived
+    /// `Hash` already includes it, so ignoring it here would let equal values hash differently.
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self.arg.eq(&x.arg) && self.fail_on_error == x.fail_on_error)
+            .unwrap_or(false)
+    }
+}
diff --git a/core/src/execution/datafusion/planner.rs
b/core/src/execution/datafusion/planner.rs
index 20119b13..3a8548f7 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -33,7 +33,7 @@ use datafusion::{
expressions::{
in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr,
Column, Count,
FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
- Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr,
Sum, UnKnownColumn,
+ Literal as DataFusionLiteral, Max, Min, NotExpr, Sum,
UnKnownColumn,
},
AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
},
@@ -70,6 +70,7 @@ use crate::{
correlation::Correlation,
covariance::Covariance,
if_expr::IfExpr,
+ negative,
scalar_funcs::create_comet_physical_fun,
stats::StatsType,
stddev::Stddev,
@@ -563,8 +564,10 @@ impl PhysicalPlanner {
Ok(Arc::new(NotExpr::new(child)))
}
ExprStruct::Negative(expr) => {
- let child = self.create_expr(expr.child.as_ref().unwrap(),
input_schema)?;
- Ok(Arc::new(NegativeExpr::new(child)))
+ let child: Arc<dyn PhysicalExpr> =
+ self.create_expr(expr.child.as_ref().unwrap(),
input_schema.clone())?;
+ let result = negative::create_negate_expr(child,
expr.fail_on_error);
+ result.map_err(|e| ExecutionError::GeneralError(e.to_string()))
}
ExprStruct::NormalizeNanAndZero(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(),
input_schema)?;
diff --git a/core/src/execution/proto/expr.proto
b/core/src/execution/proto/expr.proto
index bcd98387..9c604901 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -454,6 +454,7 @@ message Not {
message Negative {
Expr child = 1;
+ bool fail_on_error = 2; // true in ANSI mode: overflow on negation raises an error instead of wrapping
}
message IfExpr {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 79a75102..5fe290cf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1984,11 +1984,12 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
None
}
- case UnaryMinus(child, _) =>
+ case UnaryMinus(child, failOnError) =>
val childExpr = exprToProtoInternal(child, inputs)
if (childExpr.isDefined) {
val builder = ExprOuterClass.Negative.newBuilder()
builder.setChild(childExpr.get)
+ builder.setFailOnError(failOnError)
Some(
ExprOuterClass.Expr
.newBuilder()
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e15b09ca..e69054db 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1548,5 +1548,103 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
}
+  test("unary negative integer overflow test") {
+    def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = {
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> enabled.toString,
+        CometConf.COMET_ANSI_MODE_ENABLED.key -> enabled.toString,
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true")(f)
+    }
+
+    // Runs `query` on Spark and Comet. Either both fail with a message containing
+    // "<dtype> overflow", or both succeed and must then return the same answer.
+    def checkOverflow(query: String, dtype: String): Unit = {
+      checkSparkMaybeThrows(sql(query)) match {
+        case (Some(sparkException), Some(cometException)) =>
+          assert(sparkException.getMessage.contains(dtype + " overflow"))
+          assert(cometException.getMessage.contains(dtype + " overflow"))
+        case (None, None) =>
+          // No overflow was raised, so the negated values themselves must agree.
+          checkSparkAnswer(sql(query))
+        case (None, Some(ex)) =>
+          fail("Comet threw an exception but Spark did not " + ex.getMessage)
+        case (Some(_), None) =>
+          fail("Spark threw an exception but Comet did not")
+      }
+    }
+
+    // Exposes the Parquet file at `path` as table `t`. Without ANSI mode the query must run
+    // natively and match Spark; with ANSI mode both engines must behave alike.
+    def runArrayTest(query: String, dtype: String, path: String): Unit = {
+      withParquetTable(path, "t") {
+        withAnsiMode(enabled = false) {
+          checkSparkAnswerAndOperator(sql(query))
+        }
+        withAnsiMode(enabled = true) {
+          checkOverflow(query, dtype)
+        }
+      }
+    }
+
+    withTempDir { dir =>
+      // Each file below holds its type's max and min value; negating the min overflows.
+      val negateQuery = "select a, -a from t"
+      def parquetPath(fileName: String): String =
+        new Path(dir.toURI.toString, fileName).toString
+
+      // int values test
+      val arrayPath = parquetPath("array_test.parquet")
+      Seq(Int.MaxValue, Int.MinValue).toDF("a").write.mode("overwrite").parquet(arrayPath)
+      runArrayTest(negateQuery, "integer", arrayPath)
+
+      // long values test
+      val longArrayPath = parquetPath("long_array_test.parquet")
+      Seq(Long.MaxValue, Long.MinValue)
+        .toDF("a")
+        .write
+        .mode("overwrite")
+        .parquet(longArrayPath)
+      runArrayTest(negateQuery, "long", longArrayPath)
+
+      // short values test (byte/short messages name the value, so match the shared suffix)
+      val shortArrayPath = parquetPath("short_array_test.parquet")
+      Seq(Short.MaxValue, Short.MinValue)
+        .toDF("a")
+        .write
+        .mode("overwrite")
+        .parquet(shortArrayPath)
+      runArrayTest(negateQuery, " caused", shortArrayPath)
+
+      // byte values test
+      val byteArrayPath = parquetPath("byte_array_test.parquet")
+      Seq(Byte.MaxValue, Byte.MinValue)
+        .toDF("a")
+        .write
+        .mode("overwrite")
+        .parquet(byteArrayPath)
+      runArrayTest(negateQuery, " caused", byteArrayPath)
+
+      // interval values test
+      withTable("t_interval") {
+        spark.sql("CREATE TABLE t_interval(a STRING) USING PARQUET")
+        spark.sql("INSERT INTO t_interval VALUES ('INTERVAL 10000000000 YEAR')")
+        withAnsiMode(enabled = true) {
+          spark
+            .sql("SELECT CAST(a AS INTERVAL) AS a FROM t_interval")
+            .createOrReplaceTempView("t_interval_casted")
+          checkOverflow("SELECT a, -a FROM t_interval_casted", "interval")
+        }
+      }
+
+      withTable("t") {
+        sql("create table t(a int) using parquet")
+        sql("insert into t values (-2147483648)")
+        withAnsiMode(enabled = true) {
+          checkOverflow("select a, -a from t", "integer")
+        }
+      }
+
+      withTable("t_float") {
+        sql("create table t_float(a float) using parquet")
+        sql("insert into t_float values (3.4128235E38)")
+        withAnsiMode(enabled = true) {
+          checkOverflow("select a, -a from t_float", "float")
+        }
+      }
+    }
+  }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]