This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b3fe6aa68a Lead and Lag window functions should support default value with data type other than Int64 (#9001)
b3fe6aa68a is described below

commit b3fe6aa68adb644d275d8914b3802c153b4a3a27
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Jan 26 08:42:23 2024 -0800

    Lead and Lag window functions should support default value with data type other than Int64 (#9001)
---
 datafusion/common/src/scalar.rs                 | 10 ++++++++++
 datafusion/physical-expr/src/window/lead_lag.rs | 14 +++-----------
 datafusion/sqllogictest/test_files/window.slt   | 14 ++++++++++++++
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 99b8cff20d..2f9e374bd7 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -2364,6 +2364,16 @@ impl ScalarValue {
         ScalarValue::try_from_array(&cast_arr, 0)
     }
 
+    /// Try to cast this value to a ScalarValue of type `data_type`
+    pub fn cast_to(&self, data_type: &DataType) -> Result<Self> {
+        let cast_options = CastOptions {
+            safe: false,
+            format_options: Default::default(),
+        };
+        let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?;
+        ScalarValue::try_from_array(&cast_arr, 0)
+    }
+
     fn eq_array_decimal(
         array: &ArrayRef,
         index: usize,
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index d8072be839..c218b5555a 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -23,9 +23,7 @@ use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
 use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field};
-use datafusion_common::{
-    arrow_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
-};
+use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
 use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::cmp::min;
@@ -238,15 +236,9 @@ fn get_default_value(
     dtype: &DataType,
 ) -> Result<ScalarValue> {
     match default_value {
-        Some(v) if v.data_type() == DataType::Int64 => {
-            ScalarValue::try_from_string(v.to_string(), dtype)
-        }
-        Some(v) if !v.data_type().is_null() => exec_err!(
-            "Unexpected datatype for default value: {}. Expected: Int64",
-            v.data_type()
-        ),
+        Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
         // If None or Null datatype
-        _ => Ok(ScalarValue::try_from(dtype)?),
+        _ => ScalarValue::try_from(dtype),
     }
 }
 
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 303e8e035e..aec2fed738 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4004,3 +4004,17 @@ select lag(a, 1, null) over (order by a) from (select 1 a union all select 2 a)
 ----
 NULL
 1
+
+# test LEAD window function with string default value
+query T
+select lead(a, 1, 'default') over (order by a) from (select '1' a union all select '2' a)
+----
+2
+default
+
+# test LAG window function with string default value
+query T
+select lag(a, 1, 'default') over (order by a) from (select '1' a union all select '2' a)
+----
+default
+1

Reply via email to