This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b3fe6aa68a Lead and Lag window functions should support default value
with data type other than Int64 (#9001)
b3fe6aa68a is described below
commit b3fe6aa68adb644d275d8914b3802c153b4a3a27
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Jan 26 08:42:23 2024 -0800
Lead and Lag window functions should support default value with data type
other than Int64 (#9001)
---
datafusion/common/src/scalar.rs | 10 ++++++++++
datafusion/physical-expr/src/window/lead_lag.rs | 14 +++-----------
datafusion/sqllogictest/test_files/window.slt | 14 ++++++++++++++
3 files changed, 27 insertions(+), 11 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 99b8cff20d..2f9e374bd7 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -2364,6 +2364,16 @@ impl ScalarValue {
ScalarValue::try_from_array(&cast_arr, 0)
}
+ /// Try to cast this value to a ScalarValue of type `data_type`
+ pub fn cast_to(&self, data_type: &DataType) -> Result<Self> {
+ let cast_options = CastOptions {
+ safe: false,
+ format_options: Default::default(),
+ };
+ let cast_arr = cast_with_options(&self.to_array()?, data_type,
&cast_options)?;
+ ScalarValue::try_from_array(&cast_arr, 0)
+ }
+
fn eq_array_decimal(
array: &ArrayRef,
index: usize,
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index d8072be839..c218b5555a 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -23,9 +23,7 @@ use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
-use datafusion_common::{
- arrow_datafusion_err, exec_err, DataFusionError, Result, ScalarValue,
-};
+use datafusion_common::{arrow_datafusion_err, DataFusionError, Result,
ScalarValue};
use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::cmp::min;
@@ -238,15 +236,9 @@ fn get_default_value(
dtype: &DataType,
) -> Result<ScalarValue> {
match default_value {
- Some(v) if v.data_type() == DataType::Int64 => {
- ScalarValue::try_from_string(v.to_string(), dtype)
- }
- Some(v) if !v.data_type().is_null() => exec_err!(
- "Unexpected datatype for default value: {}. Expected: Int64",
- v.data_type()
- ),
+ Some(v) if !v.data_type().is_null() => v.cast_to(dtype),
// If None or Null datatype
- _ => Ok(ScalarValue::try_from(dtype)?),
+ _ => ScalarValue::try_from(dtype),
}
}
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 303e8e035e..aec2fed738 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4004,3 +4004,17 @@ select lag(a, 1, null) over (order by a) from (select 1 a union all select 2 a)
----
NULL
1
+
+# test LEAD window function with string default value
+query T
+select lead(a, 1, 'default') over (order by a) from (select '1' a union all
select '2' a)
+----
+2
+default
+
+# test LAG window function with string default value
+query T
+select lag(a, 1, 'default') over (order by a) from (select '1' a union all
select '2' a)
+----
+default
+1