This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 49237a21a Bug fix: Window frame range value outside the type range (#5384)
49237a21a is described below

commit 49237a21a707944aa0218c346f1c03e8d2cf478e
Author: Mustafa Akur <[email protected]>
AuthorDate: Fri Feb 24 22:11:37 2023 +0300

    Bug fix: Window frame range value outside the type range (#5384)
    
    * Initial Implementation
    
    * Change error type
    
    * just use largest type
    
    * Refactors, simplifications, comment improvements
    
    * add new test
    
    * Change error type
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/core/tests/sql/window.rs                |   2 +-
 .../core/tests/sqllogictests/test_files/window.slt |  28 +++-
 datafusion/expr/src/type_coercion.rs               |  13 +-
 datafusion/optimizer/src/type_coercion.rs          | 145 +++++++++++++--------
 4 files changed, 123 insertions(+), 65 deletions(-)

diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 3de7482d7..8f5d1584e 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -671,7 +671,7 @@ async fn window_frame_creation_type_checking() -> Result<()> {
     // Error is returned from the logical plan.
     check_query(
         false,
-        "Internal error: Optimizer rule 'type_coercion' failed due to 
unexpected error: Arrow error: Cast error: Cannot cast string '1 DAY' to value 
of UInt32 type"
+        "Internal error: Optimizer rule 'type_coercion' failed due to 
unexpected error: Execution error: Cannot cast Utf8(\"1 DAY\") to UInt32."
     ).await
 }
 
diff --git a/datafusion/core/tests/sqllogictests/test_files/window.slt b/datafusion/core/tests/sqllogictests/test_files/window.slt
index cbbc82c91..64920bb3d 100644
--- a/datafusion/core/tests/sqllogictests/test_files/window.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/window.slt
@@ -527,8 +527,7 @@ LIMIT 5
 #// }
 
 # async fn window_frame_ranges_preceding_following_desc
-# This query should pass. Tracked in  
https://github.com/apache/arrow-datafusion/issues/5346
-query error DataFusion error: Internal error: Operator \+ is not implemented
+query III
 SELECT
 SUM(c4) OVER(ORDER BY c2 DESC RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING),
 SUM(c3) OVER(ORDER BY c2 DESC RANGE BETWEEN 10000 PRECEDING AND 10000 
FOLLOWING),
@@ -536,6 +535,31 @@ COUNT(*) OVER(ORDER BY c2 DESC RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING)
 FROM aggregate_test_100
 ORDER BY c9
 LIMIT 5
+----
+52276 781 56
+260620 781 63
+-28623 781 37
+260620 781 63
+260620 781 63
+
+# async fn window_frame_large_range
+# Range offset 10000 is too big for Int8 (i.e. the type of c3).
+# In this case, we should still be able to produce correct results.
+# See the issue: https://github.com/apache/arrow-datafusion/issues/5346
+# The OVER clause below is equivalent to OVER(ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
+# in terms of behaviour.
+query I
+SELECT
+SUM(c3) OVER(ORDER BY c3 DESC RANGE BETWEEN 10000 PRECEDING AND 10000 
FOLLOWING) as summation1
+FROM aggregate_test_100
+ORDER BY c9
+LIMIT 5
+----
+781
+781
+781
+781
+781
 
 # async fn window_frame_order_by_asc_desc_large
 query I
diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs
index 502925e9b..ca624a877 100644
--- a/datafusion/expr/src/type_coercion.rs
+++ b/datafusion/expr/src/type_coercion.rs
@@ -33,7 +33,7 @@
 
 use arrow::datatypes::DataType;
 
-/// Determine if a DataType is signed numeric or not
+/// Determine whether the given data type `dt` represents signed numeric values.
 pub fn is_signed_numeric(dt: &DataType) -> bool {
     matches!(
         dt,
@@ -48,12 +48,12 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
     )
 }
 
-// Determine if a DataType is Null or not
+/// Determine whether the given data type `dt` is `Null`.
 pub fn is_null(dt: &DataType) -> bool {
     *dt == DataType::Null
 }
 
-/// Determine if a DataType is numeric or not
+/// Determine whether the given data type `dt` represents numeric values.
 pub fn is_numeric(dt: &DataType) -> bool {
     is_signed_numeric(dt)
         || matches!(
@@ -62,17 +62,18 @@ pub fn is_numeric(dt: &DataType) -> bool {
         )
 }
 
-/// Determine if a DataType is Timestamp or not
+/// Determine whether the given data type `dt` is a `Timestamp`.
 pub fn is_timestamp(dt: &DataType) -> bool {
     matches!(dt, DataType::Timestamp(_, _))
 }
 
-/// Determine if a DataType is Date or not
+/// Determine whether the given data type `dt` is a `Date`.
 pub fn is_date(dt: &DataType) -> bool {
     matches!(dt, DataType::Date32 | DataType::Date64)
 }
 
-pub fn is_uft8(dt: &DataType) -> bool {
+/// Determine whether the given data type `dt` is a `Utf8`.
+pub fn is_utf8(dt: &DataType) -> bool {
     matches!(dt, DataType::Utf8)
 }
 
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index 6b6cea82f..b1ee55f93 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -34,7 +34,7 @@ use datafusion_expr::type_coercion::functions::data_types;
 use datafusion_expr::type_coercion::other::{
     get_coerce_type_for_case_when, get_coerce_type_for_list,
 };
-use datafusion_expr::type_coercion::{is_date, is_numeric, is_timestamp, 
is_uft8};
+use datafusion_expr::type_coercion::{is_date, is_numeric, is_timestamp, 
is_utf8};
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
     aggregate_function, function, is_false, is_not_false, is_not_true, 
is_not_unknown,
@@ -411,7 +411,7 @@ impl ExprRewriter for TypeCoercionRewriter {
                 window_frame,
             }) => {
                 let window_frame =
-                    get_coerced_window_frame(window_frame, &self.schema, 
&order_by)?;
+                    coerce_window_frame(window_frame, &self.schema, 
&order_by)?;
                 let expr = Expr::WindowFunction(WindowFunction::new(
                     fun,
                     args,
@@ -426,95 +426,128 @@ impl ExprRewriter for TypeCoercionRewriter {
     }
 }
 
-/// Casts the ScalarValue `value` to coerced type.
-// When coerced type is `Interval` we use `parse_interval` since 
`try_from_string` not
-// supports conversion from string to Interval
-fn convert_to_coerced_type(
-    coerced_type: &DataType,
-    value: &ScalarValue,
-) -> Result<ScalarValue> {
+/// Casts the given `value` to `target_type`. Note that this function
+/// only considers `Null` or `Utf8` values.
+fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> 
Result<ScalarValue> {
     match value {
-        // In here we do casting either for NULL types or
-        // ScalarValue::Utf8(Some(val)). The other types are already casted.
-        // The reason is that we convert the sqlparser result
-        // to the Utf8 for all possible cases. Hence the types other than Utf8
-        // are already casted to appropriate type. Therefore they can be 
returned directly.
+        // Coerce Utf8 values:
         ScalarValue::Utf8(Some(val)) => {
-            // we need special handling for Interval types
-            if let DataType::Interval(..) = coerced_type {
+            // When `target_type` is `Interval`, we use `parse_interval` since
+            // `try_from_string` does not support `String` to `Interval` coercions.
+            if let DataType::Interval(..) = target_type {
                 parse_interval("millisecond", val)
             } else {
-                ScalarValue::try_from_string(val.clone(), coerced_type)
+                ScalarValue::try_from_string(val.clone(), target_type)
             }
         }
         s => {
             if s.is_null() {
-                ScalarValue::try_from(coerced_type)
+                // Coerce `Null` values:
+                ScalarValue::try_from(target_type)
             } else {
+                // Values other than the `Utf8`/`Null` variants already have the right
+                // type (they were cast earlier), since we convert `sqlparser` outputs to
+                // `Utf8` for all possible cases. Therefore, we return a clone here.
                 Ok(s.clone())
             }
         }
     }
 }
 
+/// This function coerces `value` to `target_type` in a range-aware fashion.
+/// If the coercion is successful, we return an `Ok` value with the result.
+/// If the coercion fails because `target_type` is not wide enough (i.e. we
+/// cannot coerce to `target_type`, but we can coerce to a wider type in the same
+/// family), we return a `Null` value of `target_type` to signal this situation.
+/// Downstream code uses this signal to treat these values as *unbounded*.
+fn coerce_scalar_range_aware(
+    target_type: &DataType,
+    value: &ScalarValue,
+) -> Result<ScalarValue> {
+    coerce_scalar(target_type, value).or_else(|err| {
+        // If type coercion fails, check if the largest type in family works:
+        if let Some(largest_type) = get_widest_type_in_family(target_type) {
+            coerce_scalar(largest_type, value).map_or_else(
+                |_| {
+                    Err(DataFusionError::Execution(format!(
+                        "Cannot cast {:?} to {:?}",
+                        value, target_type
+                    )))
+                },
+                |_| ScalarValue::try_from(target_type),
+            )
+        } else {
+            Err(err)
+        }
+    })
+}
+
+/// This function returns the widest type in the family of `given_type`.
+/// If the given type is already the widest in its family (or has no family), it returns `None`.
+/// For example, if `given_type` is `Int8`, it returns `Int64`.
+fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
+    match given_type {
+        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 => 
Some(&DataType::UInt64),
+        DataType::Int8 | DataType::Int16 | DataType::Int32 => 
Some(&DataType::Int64),
+        DataType::Float16 | DataType::Float32 => Some(&DataType::Float64),
+        _ => None,
+    }
+}
+
+/// Coerces the given (window frame) `bound` to `target_type`.
 fn coerce_frame_bound(
-    coerced_type: &DataType,
+    target_type: &DataType,
     bound: &WindowFrameBound,
 ) -> Result<WindowFrameBound> {
-    Ok(match bound {
-        WindowFrameBound::Preceding(val) => {
-            WindowFrameBound::Preceding(convert_to_coerced_type(coerced_type, 
val)?)
+    match bound {
+        WindowFrameBound::Preceding(v) => {
+            coerce_scalar_range_aware(target_type, 
v).map(WindowFrameBound::Preceding)
         }
-        WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
-        WindowFrameBound::Following(val) => {
-            WindowFrameBound::Following(convert_to_coerced_type(coerced_type, 
val)?)
+        WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow),
+        WindowFrameBound::Following(v) => {
+            coerce_scalar_range_aware(target_type, 
v).map(WindowFrameBound::Following)
         }
-    })
+    }
 }
 
-fn get_coerced_window_frame(
+/// Coerces the given `window_frame` to use appropriate natural types.
+/// For example, ROWS and GROUPS frames use `UInt64` during calculations.
+fn coerce_window_frame(
     window_frame: WindowFrame,
     schema: &DFSchemaRef,
     expressions: &[Expr],
 ) -> Result<WindowFrame> {
-    fn get_coerced_type(column_type: &DataType) -> Result<DataType> {
-        if is_numeric(column_type) | is_uft8(column_type) {
-            Ok(column_type.clone())
-        } else if is_timestamp(column_type) || is_date(column_type) {
-            Ok(DataType::Interval(IntervalUnit::MonthDayNano))
-        } else {
-            Err(DataFusionError::Internal(format!(
-                "Cannot run range queries on datatype: {column_type:?}"
-            )))
-        }
-    }
-
     let mut window_frame = window_frame;
     let current_types = expressions
         .iter()
         .map(|e| e.get_type(schema))
         .collect::<Result<Vec<_>>>()?;
-    match &mut window_frame.units {
+    let target_type = match window_frame.units {
         WindowFrameUnits::Range => {
-            let col_type = current_types.first().ok_or_else(|| {
-                DataFusionError::Internal("ORDER BY column cannot be 
empty".to_string())
-            })?;
-            let coerced_type = get_coerced_type(col_type)?;
-            window_frame.start_bound =
-                coerce_frame_bound(&coerced_type, &window_frame.start_bound)?;
-            window_frame.end_bound =
-                coerce_frame_bound(&coerced_type, &window_frame.end_bound)?;
-        }
-        WindowFrameUnits::Rows | WindowFrameUnits::Groups => {
-            let coerced_type = DataType::UInt64;
-            window_frame.start_bound =
-                coerce_frame_bound(&coerced_type, &window_frame.start_bound)?;
-            window_frame.end_bound =
-                coerce_frame_bound(&coerced_type, &window_frame.end_bound)?;
+            if let Some(col_type) = current_types.first() {
+                if is_numeric(col_type) || is_utf8(col_type) {
+                    col_type
+                } else if is_timestamp(col_type) || is_date(col_type) {
+                    &DataType::Interval(IntervalUnit::MonthDayNano)
+                } else {
+                    return Err(DataFusionError::Internal(format!(
+                        "Cannot run range queries on datatype: {col_type:?}"
+                    )));
+                }
+            } else {
+                return Err(DataFusionError::Internal(
+                    "ORDER BY column cannot be empty".to_string(),
+                ));
+            }
         }
-    }
+        WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64,
+    };
+    window_frame.start_bound =
+        coerce_frame_bound(target_type, &window_frame.start_bound)?;
+    window_frame.end_bound = coerce_frame_bound(target_type, 
&window_frame.end_bound)?;
     Ok(window_frame)
 }
+
 // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
 // The above op will be rewrite to the binary op when creating the physical op.
 fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> 
Result<Expr> {

Reply via email to