This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 569f6fe87e Remove input schema from PhysicalExpr, move the validation
logic to physical expression planner (#6122)
569f6fe87e is described below
commit 569f6fe87ec0035d4263f6b9f1dfbf29677054c5
Author: Ken, Wang <[email protected]>
AuthorDate: Wed Apr 26 18:45:29 2023 +0800
Remove input schema from PhysicalExpr, move the validation logic to
physical expression planner (#6122)
* Remove input schema from PhysicalExpr, move the validation logic to
physical expression planner
* fix fmt
---
.../physical-expr/src/expressions/datetime.rs | 74 +++++++++++-----------
.../physical-expr/src/expressions/in_list.rs | 46 ++++++++------
datafusion/physical-expr/src/expressions/mod.rs | 2 +-
.../physical-expr/src/intervals/test_utils.rs | 26 +++-----
datafusion/physical-expr/src/planner.rs | 30 ++-------
datafusion/physical-expr/src/utils.rs | 23 +------
datafusion/proto/src/physical_plan/from_proto.rs | 21 +++---
datafusion/proto/src/physical_plan/mod.rs | 22 +++----
8 files changed, 99 insertions(+), 145 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/datetime.rs
b/datafusion/physical-expr/src/expressions/datetime.rs
index 25f02be3b1..dae12fea73 100644
--- a/datafusion/physical-expr/src/expressions/datetime.rs
+++ b/datafusion/physical-expr/src/expressions/datetime.rs
@@ -37,45 +37,16 @@ pub struct DateTimeIntervalExpr {
lhs: Arc<dyn PhysicalExpr>,
op: Operator,
rhs: Arc<dyn PhysicalExpr>,
- // TODO: move type checking to the planning phase and not in the physical
expr
- // so we can remove this
- input_schema: Schema,
}
impl DateTimeIntervalExpr {
/// Create a new instance of DateIntervalExpr
- pub fn try_new(
+ pub fn new(
lhs: Arc<dyn PhysicalExpr>,
op: Operator,
rhs: Arc<dyn PhysicalExpr>,
- input_schema: &Schema,
- ) -> Result<Self> {
- match (
- lhs.data_type(input_schema)?,
- op,
- rhs.data_type(input_schema)?,
- ) {
- (
- DataType::Date32 | DataType::Date64 | DataType::Timestamp(_,
_),
- Operator::Plus | Operator::Minus,
- DataType::Interval(_),
- )
- | (DataType::Timestamp(_, _), Operator::Minus,
DataType::Timestamp(_, _))
- | (DataType::Interval(_), Operator::Plus, DataType::Timestamp(_,
_))
- | (
- DataType::Interval(_),
- Operator::Plus | Operator::Minus,
- DataType::Interval(_),
- ) => Ok(Self {
- lhs,
- op,
- rhs,
- input_schema: input_schema.clone(),
- }),
- (lhs, _, rhs) => Err(DataFusionError::Execution(format!(
- "Invalid operation {op} between '{lhs}' and '{rhs}' for
DateIntervalExpr"
- ))),
- }
+ ) -> Self {
+ Self { lhs, op, rhs }
}
/// Get the left-hand side expression
@@ -202,12 +173,11 @@ impl PhysicalExpr for DateTimeIntervalExpr {
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
- Ok(Arc::new(DateTimeIntervalExpr::try_new(
+ Ok(Arc::new(DateTimeIntervalExpr::new(
children[0].clone(),
self.op,
children[1].clone(),
- &self.input_schema,
- )?))
+ )))
}
}
@@ -220,6 +190,36 @@ impl PartialEq<dyn Any> for DateTimeIntervalExpr {
}
}
+/// create a DateIntervalExpr
+pub fn date_time_interval_expr(
+ lhs: Arc<dyn PhysicalExpr>,
+ op: Operator,
+ rhs: Arc<dyn PhysicalExpr>,
+ input_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+ match (
+ lhs.data_type(input_schema)?,
+ op,
+ rhs.data_type(input_schema)?,
+ ) {
+ (
+ DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
+ Operator::Plus | Operator::Minus,
+ DataType::Interval(_),
+ )
+ | (DataType::Timestamp(_, _), Operator::Minus, DataType::Timestamp(_,
_))
+ | (DataType::Interval(_), Operator::Plus, DataType::Timestamp(_, _))
+ | (
+ DataType::Interval(_),
+ Operator::Plus | Operator::Minus,
+ DataType::Interval(_),
+ ) => Ok(Arc::new(DateTimeIntervalExpr::new(lhs, op, rhs))),
+ (lhs, _, rhs) => Err(DataFusionError::Execution(format!(
+ "Invalid operation {op} between '{lhs}' and '{rhs}' for
DateIntervalExpr"
+ ))),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -535,7 +535,7 @@ mod tests {
let lhs = create_physical_expr(&dt, &dfs, &schema, &props)?;
let rhs = create_physical_expr(&interval, &dfs, &schema, &props)?;
- let cut = DateTimeIntervalExpr::try_new(lhs, op, rhs, &schema)?;
+ let cut = date_time_interval_expr(lhs, op, rhs, &schema)?;
let res = cut.evaluate(&batch)?;
let mut builder = Date32Builder::with_capacity(8);
@@ -613,7 +613,7 @@ mod tests {
let lhs_str = format!("{lhs}");
let rhs_str = format!("{rhs}");
- let cut = DateTimeIntervalExpr::try_new(lhs, op, rhs, &schema)?;
+ let cut = DateTimeIntervalExpr::new(lhs, op, rhs);
assert_eq!(lhs_str, format!("{}", cut.lhs()));
assert_eq!(op, cut.op().clone());
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs
b/datafusion/physical-expr/src/expressions/in_list.rs
index 575050cfff..3feb728900 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -47,8 +47,7 @@ pub struct InListExpr {
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: bool,
- static_filter: Option<Box<dyn Set>>,
- input_schema: Schema,
+ static_filter: Option<Arc<dyn Set>>,
}
impl Debug for InListExpr {
@@ -62,7 +61,7 @@ impl Debug for InListExpr {
}
/// A type-erased container of array elements
-trait Set: Send + Sync {
+pub trait Set: Send + Sync {
fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
}
@@ -172,36 +171,36 @@ where
}
/// Creates a `Box<dyn Set>` for the given list of `IN` expressions and `batch`
-fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
+fn make_set(array: &dyn Array) -> Result<Arc<dyn Set>> {
Ok(downcast_primitive_array! {
- array => Box::new(ArraySet::new(array, make_hash_set(array))),
+ array => Arc::new(ArraySet::new(array, make_hash_set(array))),
DataType::Boolean => {
let array = as_boolean_array(array)?;
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
},
DataType::Decimal128(_, _) => {
let array = as_primitive_array::<Decimal128Type>(array)?;
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Decimal256(_, _) => {
let array = as_primitive_array::<Decimal256Type>(array)?;
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Utf8 => {
let array = as_string_array(array)?;
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::LargeUtf8 => {
let array = as_largestring_array(array);
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Binary => {
let array = as_generic_binary_array::<i32>(array)?;
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::LargeBinary => {
let array = as_generic_binary_array::<i64>(array)?;
- Box::new(ArraySet::new(array, make_hash_set(array)))
+ Arc::new(ArraySet::new(array, make_hash_set(array)))
}
DataType::Dictionary(_, _) => unreachable!("dictionary should have
been flattened"),
d => return Err(DataFusionError::NotImplemented(format!("DataType::{d}
not supported in InList")))
@@ -233,7 +232,7 @@ fn evaluate_list(
fn try_cast_static_filter_to_set(
list: &[Arc<dyn PhysicalExpr>],
schema: &Schema,
-) -> Result<Box<dyn Set>> {
+) -> Result<Arc<dyn Set>> {
let batch = RecordBatch::new_empty(Arc::new(schema.clone()));
make_set(evaluate_list(list, &batch)?.as_ref())
}
@@ -244,15 +243,13 @@ impl InListExpr {
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: bool,
- schema: &Schema,
+ static_filter: Option<Arc<dyn Set>>,
) -> Self {
- let static_filter = try_cast_static_filter_to_set(&list, schema).ok();
Self {
expr,
list,
negated,
static_filter,
- input_schema: schema.clone(),
}
}
@@ -325,12 +322,13 @@ impl PhysicalExpr for InListExpr {
self: Arc<Self>,
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
- in_list(
+ // assume the static_filter will not change during the rewrite process
+ Ok(Arc::new(InListExpr::new(
children[0].clone(),
children[1..].to_vec(),
- &self.negated,
- &self.input_schema,
- )
+ self.negated,
+ self.static_filter.clone(),
+ )))
}
}
@@ -364,7 +362,13 @@ pub fn in_list(
)));
}
}
- Ok(Arc::new(InListExpr::new(expr, list, *negated, schema)))
+ let static_filter = try_cast_static_filter_to_set(&list, schema).ok();
+ Ok(Arc::new(InListExpr::new(
+ expr,
+ list,
+ *negated,
+ static_filter,
+ )))
}
#[cfg(test)]
diff --git a/datafusion/physical-expr/src/expressions/mod.rs
b/datafusion/physical-expr/src/expressions/mod.rs
index ad4b7031c0..135e24dc83 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -78,7 +78,7 @@ pub use cast::{
cast, cast_column, cast_with_options, CastExpr,
DEFAULT_DATAFUSION_CAST_OPTIONS,
};
pub use column::{col, Column, UnKnownColumn};
-pub use datetime::DateTimeIntervalExpr;
+pub use datetime::{date_time_interval_expr, DateTimeIntervalExpr};
pub use get_indexed_field::GetIndexedFieldExpr;
pub use in_list::{in_list, InListExpr};
pub use is_not_null::{is_not_null, IsNotNullExpr};
diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs
b/datafusion/physical-expr/src/intervals/test_utils.rs
index 6bbf74dc7d..8e695c2556 100644
--- a/datafusion/physical-expr/src/intervals/test_utils.rs
+++ b/datafusion/physical-expr/src/intervals/test_utils.rs
@@ -19,7 +19,7 @@
use std::sync::Arc;
-use crate::expressions::{BinaryExpr, DateTimeIntervalExpr, Literal};
+use crate::expressions::{date_time_interval_expr, BinaryExpr, Literal};
use crate::PhysicalExpr;
use arrow_schema::Schema;
use datafusion_common::{DataFusionError, ScalarValue};
@@ -78,30 +78,22 @@ pub fn gen_conjunctive_temporal_expr(
d: ScalarValue,
schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
- let left_and_1 = Arc::new(DateTimeIntervalExpr::try_new(
+ let left_and_1 = date_time_interval_expr(
left_col.clone(),
op_1,
Arc::new(Literal::new(a)),
schema,
- )?);
- let left_and_2 = Arc::new(DateTimeIntervalExpr::try_new(
+ )?;
+ let left_and_2 = date_time_interval_expr(
right_col.clone(),
op_2,
Arc::new(Literal::new(b)),
schema,
- )?);
- let right_and_1 = Arc::new(DateTimeIntervalExpr::try_new(
- left_col,
- op_3,
- Arc::new(Literal::new(c)),
- schema,
- )?);
- let right_and_2 = Arc::new(DateTimeIntervalExpr::try_new(
- right_col,
- op_4,
- Arc::new(Literal::new(d)),
- schema,
- )?);
+ )?;
+ let right_and_1 =
+ date_time_interval_expr(left_col, op_3, Arc::new(Literal::new(c)),
schema)?;
+ let right_and_2 =
+ date_time_interval_expr(right_col, op_4, Arc::new(Literal::new(d)),
schema)?;
let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt,
left_and_2));
let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt,
right_and_2));
Ok(Arc::new(BinaryExpr::new(
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 0266ecfd2e..f1bb35c2e2 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -19,7 +19,7 @@ use crate::var_provider::is_system_variables;
use crate::{
execution_props::ExecutionProps,
expressions::{
- self, binary, like, Column, DateTimeIntervalExpr, GetIndexedFieldExpr,
Literal,
+ self, binary, date_time_interval_expr, like, Column,
GetIndexedFieldExpr, Literal,
},
functions, udf,
var_provider::VarType,
@@ -195,42 +195,22 @@ pub fn create_physical_expr(
DataType::Date32 | DataType::Date64 |
DataType::Timestamp(_, _),
Operator::Plus | Operator::Minus,
DataType::Interval(_),
- ) => Ok(Arc::new(DateTimeIntervalExpr::try_new(
- lhs,
- *op,
- rhs,
- input_schema,
- )?)),
+ ) => Ok(date_time_interval_expr(lhs, *op, rhs, input_schema)?),
(
DataType::Interval(_),
Operator::Plus | Operator::Minus,
DataType::Date32 | DataType::Date64 |
DataType::Timestamp(_, _),
- ) => Ok(Arc::new(DateTimeIntervalExpr::try_new(
- rhs,
- *op,
- lhs,
- input_schema,
- )?)),
+ ) => Ok(date_time_interval_expr(rhs, *op, lhs, input_schema)?),
(
DataType::Timestamp(_, _),
Operator::Minus,
DataType::Timestamp(_, _),
- ) => Ok(Arc::new(DateTimeIntervalExpr::try_new(
- lhs,
- *op,
- rhs,
- input_schema,
- )?)),
+ ) => Ok(date_time_interval_expr(lhs, *op, rhs, input_schema)?),
(
DataType::Interval(_),
Operator::Plus | Operator::Minus,
DataType::Interval(_),
- ) => Ok(Arc::new(DateTimeIntervalExpr::try_new(
- lhs,
- *op,
- rhs,
- input_schema,
- )?)),
+ ) => Ok(date_time_interval_expr(lhs, *op, rhs, input_schema)?),
_ => {
// Note that the logical planner is responsible
// for type coercion on the arguments (e.g. if one
diff --git a/datafusion/physical-expr/src/utils.rs
b/datafusion/physical-expr/src/utils.rs
index 89b51dada8..70297bce78 100644
--- a/datafusion/physical-expr/src/utils.rs
+++ b/datafusion/physical-expr/src/utils.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::equivalence::EquivalentClass;
-use crate::expressions::{BinaryExpr, Column, InListExpr, UnKnownColumn};
+use crate::expressions::{BinaryExpr, Column, UnKnownColumn};
use crate::{
EquivalenceProperties, PhysicalExpr, PhysicalSortExpr,
PhysicalSortRequirement,
};
@@ -586,28 +586,7 @@ pub fn reassign_predicate_columns(
column.name(),
index,
))));
- } else if let Some(in_list) = expr_any.downcast_ref::<InListExpr>() {
- // transform child first
- let expr = reassign_predicate_columns(
- in_list.expr().clone(),
- schema,
- ignore_not_found,
- )?;
- let list = in_list
- .list()
- .iter()
- .map(|expr| {
- reassign_predicate_columns(expr.clone(), schema,
ignore_not_found)
- })
- .collect::<Result<Vec<_>>>()?;
- return Ok(Transformed::Yes(Arc::new(InListExpr::new(
- expr,
- list,
- in_list.negated(),
- schema.as_ref(),
- ))));
}
-
Ok(Transformed::No(expr))
})
}
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs
b/datafusion/proto/src/physical_plan/from_proto.rs
index b38ea2b35e..8e9ba1f605 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -27,15 +27,16 @@ use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::execution::context::ExecutionProps;
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::window_function::WindowFunction;
-use datafusion::physical_expr::expressions::DateTimeIntervalExpr;
use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr};
-use datafusion::physical_plan::expressions::GetIndexedFieldExpr;
-use datafusion::physical_plan::expressions::LikeExpr;
+use datafusion::physical_plan::expressions::{
+ date_time_interval_expr, GetIndexedFieldExpr,
+};
+use datafusion::physical_plan::expressions::{in_list, LikeExpr};
use datafusion::physical_plan::file_format::FileScanConfig;
use datafusion::physical_plan::{
expressions::{
- BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr,
IsNullExpr,
- Literal, NegativeExpr, NotExpr, TryCastExpr,
DEFAULT_DATAFUSION_CAST_OPTIONS,
+ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr,
Literal,
+ NegativeExpr, NotExpr, TryCastExpr, DEFAULT_DATAFUSION_CAST_OPTIONS,
},
functions, Partitioning,
};
@@ -99,7 +100,7 @@ pub fn parse_physical_expr(
input_schema,
)?,
)),
- ExprType::DateTimeIntervalExpr(expr) =>
Arc::new(DateTimeIntervalExpr::try_new(
+ ExprType::DateTimeIntervalExpr(expr) => date_time_interval_expr(
parse_required_physical_expr(
expr.l.as_deref(),
registry,
@@ -114,7 +115,7 @@ pub fn parse_physical_expr(
input_schema,
)?,
input_schema,
- )?),
+ )?,
ExprType::AggregateExpr(_) => {
return Err(DataFusionError::NotImplemented(
"Cannot convert aggregate expr node to physical
expression".to_owned(),
@@ -160,7 +161,7 @@ pub fn parse_physical_expr(
input_schema,
)?))
}
- ExprType::InList(e) => Arc::new(InListExpr::new(
+ ExprType::InList(e) => in_list(
parse_required_physical_expr(
e.expr.as_deref(),
registry,
@@ -171,9 +172,9 @@ pub fn parse_physical_expr(
.iter()
.map(|x| parse_physical_expr(x, registry, input_schema))
.collect::<Result<Vec<_>, _>>()?,
- e.negated,
+ &e.negated,
input_schema,
- )),
+ )?,
ExprType::Case(e) => Arc::new(CaseExpr::try_new(
e.expr
.as_ref()
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 2c67e0b9b4..cd8950e97b 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1243,10 +1243,12 @@ mod roundtrip_tests {
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_expr::create_udf;
use datafusion::logical_expr::{BuiltinScalarFunction, Volatility};
- use datafusion::physical_expr::expressions::DateTimeIntervalExpr;
+ use datafusion::physical_expr::expressions::in_list;
use datafusion::physical_expr::ScalarFunctionExpr;
use datafusion::physical_plan::aggregates::PhysicalGroupBy;
- use datafusion::physical_plan::expressions::{like, BinaryExpr,
GetIndexedFieldExpr};
+ use datafusion::physical_plan::expressions::{
+ date_time_interval_expr, like, BinaryExpr, GetIndexedFieldExpr,
+ };
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{functions, udaf};
@@ -1260,7 +1262,7 @@ mod roundtrip_tests {
physical_plan::{
aggregates::{AggregateExec, AggregateMode},
empty::EmptyExec,
- expressions::{binary, col, lit, InListExpr, NotExpr},
+ expressions::{binary, col, lit, NotExpr},
expressions::{Avg, Column, DistinctCount, PhysicalSortExpr},
file_format::{FileScanConfig, ParquetExec},
filter::FilterExec,
@@ -1326,12 +1328,8 @@ mod roundtrip_tests {
let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone())));
let date_expr = col("some_date", &schema)?;
let literal_expr = col("some_interval", &schema)?;
- let date_time_interval_expr = Arc::new(DateTimeIntervalExpr::try_new(
- date_expr,
- Operator::Plus,
- literal_expr,
- &schema,
- )?);
+ let date_time_interval_expr =
+ date_time_interval_expr(date_expr, Operator::Plus, literal_expr,
&schema)?;
let plan = Arc::new(ProjectionExec::try_new(
vec![(date_time_interval_expr, "result".to_string())],
input,
@@ -1510,15 +1508,15 @@ mod roundtrip_tests {
let field_c = Field::new("c", DataType::Int64, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c]));
let not = Arc::new(NotExpr::new(col("a", &schema)?));
- let in_list = Arc::new(InListExpr::new(
+ let in_list = in_list(
col("b", &schema)?,
vec![
lit(ScalarValue::Int64(Some(1))),
lit(ScalarValue::Int64(Some(2))),
],
- false,
+ &false,
schema.as_ref(),
- ));
+ )?;
let and = binary(not, Operator::And, in_list, &schema)?;
roundtrip_test(Arc::new(FilterExec::try_new(
and,