This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 0ce6f1b1f fix bug: the static/constant set value must be evaluated and get the result (#2834)
0ce6f1b1f is described below
commit 0ce6f1b1fd8fbf94238db913f8bc884e3c3c6aeb
Author: Kun Liu <[email protected]>
AuthorDate: Thu Jul 7 18:17:41 2022 +0800
fix bug: the static/constant set value must be evaluated and get the result (#2834)
---
datafusion/core/src/physical_plan/planner.rs | 5 +-
.../physical-expr/src/expressions/in_list.rs | 129 +++++++++++++--------
datafusion/physical-expr/src/planner.rs | 2 +-
3 files changed, 82 insertions(+), 54 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 0cb037620..40daef606 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -1890,14 +1890,13 @@ mod tests {
for i in 1..31 {
list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
}
-
let logical_plan = test_csv_scan()
.await?
.filter(col("c12").lt(lit(0.05)))?
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
- let expected = "expr: [(InListExpr { expr: Column { name: \"c1\",
index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal
{ value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value:
Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) },
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type:
Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 },
TryCastExpr { expr: Literal [...]
+ let expected = "expr: [(InListExpr { expr: Column { name: \"c1\",
index: 0 }, list: [Literal { value: Utf8(\"a\") }, TryCastExpr { expr: Literal
{ value: Int64(1) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value:
Int64(2) }, cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(3) },
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(4) }, cast_type:
Utf8 }, TryCastExpr { expr: Literal { value: Int64(5) }, cast_type: Utf8 },
TryCastExpr { expr: Literal [...]
assert!(format!("{:?}", execution_plan).contains(expected));
Ok(())
}
@@ -1916,7 +1915,7 @@ mod tests {
.project(vec![col("c1").in_list(list, false)])?
.build()?;
let execution_plan = plan(&logical_plan).await?;
- let expected = "expr: [(InListExpr { expr: Column { name: \"c1\",
index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) },
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type:
Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 },
TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr
{ expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr:
Literal { value: Int64(5) }, cast_ty [...]
+ let expected = "expr: [(InListExpr { expr: Column { name: \"c1\",
index: 0 }, list: [TryCastExpr { expr: Literal { value: Int64(NULL) },
cast_type: Utf8 }, TryCastExpr { expr: Literal { value: Int64(1) }, cast_type:
Utf8 }, TryCastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8 },
TryCastExpr { expr: Literal { value: Int64(3) }, cast_type: Utf8 }, TryCastExpr
{ expr: Literal { value: Int64(4) }, cast_type: Utf8 }, TryCastExpr { expr:
Literal { value: Int64(5) }, cast_ty [...]
assert!(format!("{:?}", execution_plan).contains(expected));
Ok(())
}
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index c9a3e419a..40a253c33 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -32,7 +32,7 @@ use arrow::{
record_batch::RecordBatch,
};
-use crate::{expressions, PhysicalExpr};
+use crate::PhysicalExpr;
use arrow::array::*;
use arrow::buffer::{Buffer, MutableBuffer};
use datafusion_common::ScalarValue;
@@ -269,40 +269,26 @@ fn not_in_list_utf8<OffsetSize: OffsetSizeTrait>(
compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
}
-//check all filter values of In clause are static.
-//include `CastExpr + Literal` or `Literal`
-fn check_all_static_filter_expr(list: &[Arc<dyn PhysicalExpr>]) -> bool {
- list.iter().all(|v| {
- let cast = v.as_any().downcast_ref::<expressions::CastExpr>();
- if let Some(c) = cast {
- c.expr()
- .as_any()
- .downcast_ref::<expressions::Literal>()
- .is_some()
- } else {
- let cast = v.as_any().downcast_ref::<expressions::Literal>();
- cast.is_some()
- }
- })
-}
-
-fn cast_static_filter_to_set(list: &[Arc<dyn PhysicalExpr>]) ->
HashSet<ScalarValue> {
- HashSet::from_iter(list.iter().map(|expr| {
- if let Some(cast) =
expr.as_any().downcast_ref::<expressions::CastExpr>() {
- cast.expr()
- .as_any()
- .downcast_ref::<expressions::Literal>()
- .unwrap()
- .value()
- .clone()
- } else {
- expr.as_any()
- .downcast_ref::<expressions::Literal>()
- .unwrap()
- .value()
- .clone()
- }
- }))
+// try evaluate all list exprs and check if the exprs are constants or not
+fn try_cast_static_filter_to_set(
+ list: &[Arc<dyn PhysicalExpr>],
+ schema: &Schema,
+) -> Result<HashSet<ScalarValue>> {
+ let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));
+ match list
+ .iter()
+ .map(|expr| match expr.evaluate(&batch) {
+ Ok(ColumnarValue::Array(_)) => Err(DataFusionError::NotImplemented(
+ "InList doesn't support to evaluate the array
result".to_string(),
+ )),
+ Ok(ColumnarValue::Scalar(s)) => Ok(s),
+ Err(e) => Err(e),
+ })
+ .collect::<Result<Vec<_>>>()
+ {
+ Ok(s) => Ok(HashSet::from_iter(s)),
+ Err(e) => Err(e),
+ }
}
fn make_list_contains_decimal(
@@ -379,22 +365,24 @@ impl InListExpr {
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: bool,
+ schema: &Schema,
) -> Self {
- if list.len() > OPTIMIZER_INSET_THRESHOLD &&
check_all_static_filter_expr(&list) {
- Self {
- expr,
- set: Some(InSet::new(cast_static_filter_to_set(&list))),
- list,
- negated,
- }
- } else {
- Self {
- expr,
- list,
- negated,
- set: None,
+ if list.len() > OPTIMIZER_INSET_THRESHOLD {
+ if let Ok(set) = try_cast_static_filter_to_set(&list, schema) {
+ return Self {
+ expr,
+ set: Some(InSet::new(set)),
+ list,
+ negated,
+ };
}
}
+ Self {
+ expr,
+ list,
+ negated,
+ set: None,
+ }
}
/// Input expression
@@ -795,8 +783,9 @@ pub fn in_list(
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: &bool,
+ schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
- Ok(Arc::new(InListExpr::new(expr, list, *negated)))
+ Ok(Arc::new(InListExpr::new(expr, list, *negated, schema)))
}
#[cfg(test)]
@@ -804,6 +793,7 @@ mod tests {
use arrow::{array::StringArray, datatypes::Field};
use super::*;
+ use crate::expressions;
use crate::expressions::{col, lit};
use crate::planner::in_list_cast;
use datafusion_common::Result;
@@ -812,7 +802,7 @@ mod tests {
macro_rules! in_list {
($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr,
$SCHEMA:expr) => {{
let (cast_expr, cast_list_exprs) = in_list_cast($COL, $LIST,
$SCHEMA)?;
- let expr = in_list(cast_expr, cast_list_exprs, $NEGATED).unwrap();
+ let expr = in_list(cast_expr, cast_list_exprs, $NEGATED,
$SCHEMA).unwrap();
let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
let result = result
.as_any()
@@ -1326,4 +1316,43 @@ mod tests {
);
Ok(())
}
+
+ #[test]
+ fn test_cast_static_filter_to_set() -> Result<()> {
+ // random schema
+ let schema = Schema::new(vec![Field::new("a", DataType::Decimal(13,
4), true)]);
+ // list of phy expr
+ let mut phy_exprs = vec![
+ lit(1i64),
+ expressions::cast(lit(2i32), &schema, DataType::Int64)?,
+ expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?,
+ ];
+ let result = try_cast_static_filter_to_set(&phy_exprs,
&schema).unwrap();
+
+ assert!(result.contains(&1i64.try_into().unwrap()));
+ assert!(result.contains(&2i64.try_into().unwrap()));
+ assert!(result.contains(&3i64.try_into().unwrap()));
+
+ assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+ // cast(cast(lit())), but the cast to the same data type, one case
will be ignored
+ phy_exprs.push(expressions::cast(
+ expressions::cast(lit(2i32), &schema, DataType::Int64)?,
+ &schema,
+ DataType::Int64,
+ )?);
+ assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+ // case(cast(lit())), the cast to the diff data type
+ phy_exprs.push(expressions::cast(
+ expressions::cast(lit(2i32), &schema, DataType::Int64)?,
+ &schema,
+ DataType::Int32,
+ )?);
+ assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+
+ // column
+ phy_exprs.push(expressions::col("a", &schema)?);
+ assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err());
+
+ Ok(())
+ }
}
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 1eb975bed..570f438e2 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -302,7 +302,7 @@ pub fn create_physical_expr(
let (cast_expr, cast_list_exprs) =
in_list_cast(value_expr, list_exprs, input_schema)?;
- expressions::in_list(cast_expr, cast_list_exprs, negated)
+ expressions::in_list(cast_expr, cast_list_exprs, negated,
input_schema)
}
},
other => Err(DataFusionError::NotImplemented(format!(