This is an automated email from the ASF dual-hosted git repository. blaginin pushed a commit to branch annarose/dict-coercion in repository https://gitbox.apache.org/repos/asf/datafusion-sandbox.git
commit 29d63c19641de6211c6c8f3bbaa16b359caeb429 Author: Adam Gutglick <[email protected]> AuthorDate: Tue Feb 3 12:22:33 2026 +0000 Optimize `PhysicalExprSimplifier` (#20111) ## Which issue does this PR close? - Related to #20078 ## Rationale for this change An attempt at reducing the cost of physical expression simplification. ## What changes are included in this PR? 1. The most important change in this PR is that an expression that is already a literal is no longer transformed, which means we can stop transforming the tree much earlier. Currently on main, even expressions like `lit(5)` end up running through the loop 5 times. With this change, the PR reaches a ~96% improvement on the benchmark. 2. Allocate a single dummy record batch for simplifying const expressions, instead of one per `simplify_const_expr` call. 3. Adds the benchmark I've been using to test the impact of changes. 4. `simplify_not_expr` and `simplify_const_expr` now take an `Arc` instead of `&Arc`. ## Are these changes tested? All existing tests pass with minor modifications. ## Are there any user-facing changes? Two of the individual recursive simplification functions (`simplify_not_expr` and `simplify_const_expr`) are public. This PR breaks their signatures, but I think we should consider also making them private. 
--------- Signed-off-by: Adam Gutglick <[email protected]> --- datafusion/physical-expr/Cargo.toml | 4 + datafusion/physical-expr/benches/simplify.rs | 299 +++++++++++++++++++++ .../src/simplifier/const_evaluator.rs | 25 +- datafusion/physical-expr/src/simplifier/mod.rs | 19 +- datafusion/physical-expr/src/simplifier/not.rs | 6 +- .../physical-expr/src/simplifier/unwrap_cast.rs | 33 ++- 6 files changed, 351 insertions(+), 35 deletions(-) diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 1b23beeaa..7e61be3a1 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -85,5 +85,9 @@ name = "is_null" harness = false name = "binary_op" +[[bench]] +harness = false +name = "simplify" + [package.metadata.cargo-machete] ignored = ["half"] diff --git a/datafusion/physical-expr/benches/simplify.rs b/datafusion/physical-expr/benches/simplify.rs new file mode 100644 index 000000000..cc00c7100 --- /dev/null +++ b/datafusion/physical-expr/benches/simplify.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This is an attempt at reproducing some predicates generated by TPC-DS query #76, +//! 
and trying to figure out how long it takes to simplify them. + +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; +use std::hint::black_box; +use std::sync::Arc; + +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; + +use datafusion_physical_expr::expressions::{ + BinaryExpr, CaseExpr, Column, IsNullExpr, Literal, +}; + +fn catalog_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("cs_sold_date_sk", DataType::Int64, true), // 0 + Field::new("cs_sold_time_sk", DataType::Int64, true), // 1 + Field::new("cs_ship_date_sk", DataType::Int64, true), // 2 + Field::new("cs_bill_customer_sk", DataType::Int64, true), // 3 + Field::new("cs_bill_cdemo_sk", DataType::Int64, true), // 4 + Field::new("cs_bill_hdemo_sk", DataType::Int64, true), // 5 + Field::new("cs_bill_addr_sk", DataType::Int64, true), // 6 + Field::new("cs_ship_customer_sk", DataType::Int64, true), // 7 + Field::new("cs_ship_cdemo_sk", DataType::Int64, true), // 8 + Field::new("cs_ship_hdemo_sk", DataType::Int64, true), // 9 + Field::new("cs_ship_addr_sk", DataType::Int64, true), // 10 + Field::new("cs_call_center_sk", DataType::Int64, true), // 11 + Field::new("cs_catalog_page_sk", DataType::Int64, true), // 12 + Field::new("cs_ship_mode_sk", DataType::Int64, true), // 13 + Field::new("cs_warehouse_sk", DataType::Int64, true), // 14 + Field::new("cs_item_sk", DataType::Int64, true), // 15 + Field::new("cs_promo_sk", DataType::Int64, true), // 16 + Field::new("cs_order_number", DataType::Int64, true), // 17 + Field::new("cs_quantity", DataType::Int64, true), // 18 + Field::new("cs_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_discount_amt", 
DataType::Decimal128(7, 2), true), + Field::new("cs_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("cs_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("cs_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("cs_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +fn web_sales_schema() -> Schema { + Schema::new(vec![ + Field::new("ws_sold_date_sk", DataType::Int64, true), + Field::new("ws_sold_time_sk", DataType::Int64, true), + Field::new("ws_ship_date_sk", DataType::Int64, true), + Field::new("ws_item_sk", DataType::Int64, true), + Field::new("ws_bill_customer_sk", DataType::Int64, true), + Field::new("ws_bill_cdemo_sk", DataType::Int64, true), + Field::new("ws_bill_hdemo_sk", DataType::Int64, true), + Field::new("ws_bill_addr_sk", DataType::Int64, true), + Field::new("ws_ship_customer_sk", DataType::Int64, true), + Field::new("ws_ship_cdemo_sk", DataType::Int64, true), + Field::new("ws_ship_hdemo_sk", DataType::Int64, true), + Field::new("ws_ship_addr_sk", DataType::Int64, true), + Field::new("ws_web_page_sk", DataType::Int64, true), + Field::new("ws_web_site_sk", DataType::Int64, true), + Field::new("ws_ship_mode_sk", DataType::Int64, true), + Field::new("ws_warehouse_sk", DataType::Int64, true), + Field::new("ws_promo_sk", DataType::Int64, true), + Field::new("ws_order_number", DataType::Int64, true), + Field::new("ws_quantity", DataType::Int64, true), + Field::new("ws_wholesale_cost", DataType::Decimal128(7, 2), true), + 
Field::new("ws_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_discount_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_sales_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_wholesale_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_list_price", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_coupon_amt", DataType::Decimal128(7, 2), true), + Field::new("ws_ext_ship_cost", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship", DataType::Decimal128(7, 2), true), + Field::new("ws_net_paid_inc_ship_tax", DataType::Decimal128(7, 2), true), + Field::new("ws_net_profit", DataType::Decimal128(7, 2), true), + ]) +} + +// Helper to create a literal +fn lit_i64(val: i64) -> Arc<dyn PhysicalExpr> { + Arc::new(Literal::new(ScalarValue::Int64(Some(val)))) +} + +fn lit_i32(val: i32) -> Arc<dyn PhysicalExpr> { + Arc::new(Literal::new(ScalarValue::Int32(Some(val)))) +} + +fn lit_bool(val: bool) -> Arc<dyn PhysicalExpr> { + Arc::new(Literal::new(ScalarValue::Boolean(Some(val)))) +} + +// Helper to create binary expressions +fn and( + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, +) -> Arc<dyn PhysicalExpr> { + Arc::new(BinaryExpr::new(left, Operator::And, right)) +} + +fn gte( + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, +) -> Arc<dyn PhysicalExpr> { + Arc::new(BinaryExpr::new(left, Operator::GtEq, right)) +} + +fn lte( + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, +) -> Arc<dyn PhysicalExpr> { + Arc::new(BinaryExpr::new(left, Operator::LtEq, right)) +} + +fn modulo( + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, +) -> Arc<dyn PhysicalExpr> { + Arc::new(BinaryExpr::new(left, 
Operator::Modulo, right)) +} + +fn eq( + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, +) -> Arc<dyn PhysicalExpr> { + Arc::new(BinaryExpr::new(left, Operator::Eq, right)) +} + +/// Build a predicate similar to TPC-DS q76 catalog_sales filter. +/// Uses placeholder columns instead of hash expressions. +pub fn catalog_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> { + let cs_sold_date_sk: Arc<dyn PhysicalExpr> = + Arc::new(Column::new("cs_sold_date_sk", 0)); + let cs_ship_addr_sk: Arc<dyn PhysicalExpr> = + Arc::new(Column::new("cs_ship_addr_sk", 10)); + let cs_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("cs_item_sk", 15)); + + // Use a simple modulo expression as placeholder for hash + let item_hash_mod = modulo(cs_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(cs_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // cs_ship_addr_sk IS NULL + let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(cs_ship_addr_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_item_sk.clone(), lit_i64(partition as i64)), + lte(cs_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc<dyn PhysicalExpr> = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(cs_sold_date_sk.clone(), lit_i64(2415000 + partition as i64)), + lte(cs_sold_date_sk.clone(), lit_i64(2488070)), + ); + 
(when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc<dyn PhysicalExpr> = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + // Final: is_null AND item_case AND date_case + and(and(is_null_expr, item_case_expr), date_case_expr) +} +/// Build a predicate similar to TPC-DS q76 web_sales filter. +/// Uses placeholder columns instead of hash expressions. +fn web_sales_predicate(num_partitions: usize) -> Arc<dyn PhysicalExpr> { + let ws_sold_date_sk: Arc<dyn PhysicalExpr> = + Arc::new(Column::new("ws_sold_date_sk", 0)); + let ws_item_sk: Arc<dyn PhysicalExpr> = Arc::new(Column::new("ws_item_sk", 3)); + let ws_ship_customer_sk: Arc<dyn PhysicalExpr> = + Arc::new(Column::new("ws_ship_customer_sk", 8)); + + // Use simple modulo expression as placeholder for hash + let item_hash_mod = modulo(ws_item_sk.clone(), lit_i64(num_partitions as i64)); + let date_hash_mod = modulo(ws_sold_date_sk.clone(), lit_i64(num_partitions as i64)); + + // ws_ship_customer_sk IS NULL + let is_null_expr: Arc<dyn PhysicalExpr> = + Arc::new(IsNullExpr::new(ws_ship_customer_sk)); + + // Build item_sk CASE expression with num_partitions branches + let item_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(item_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_item_sk.clone(), lit_i64(partition as i64)), + lte(ws_item_sk.clone(), lit_i64(18000)), + ); + (when_expr, then_expr) + }) + .collect(); + + let item_case_expr: Arc<dyn PhysicalExpr> = + Arc::new(CaseExpr::try_new(None, item_when_then, Some(lit_bool(false))).unwrap()); + + // Build sold_date_sk CASE expression with num_partitions branches + let date_when_then: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = (0 + ..num_partitions) + .map(|partition| { + let when_expr = eq(date_hash_mod.clone(), lit_i32(partition as i32)); + let then_expr = and( + gte(ws_sold_date_sk.clone(), 
lit_i64(2415000 + partition as i64)), + lte(ws_sold_date_sk.clone(), lit_i64(2488070)), + ); + (when_expr, then_expr) + }) + .collect(); + + let date_case_expr: Arc<dyn PhysicalExpr> = + Arc::new(CaseExpr::try_new(None, date_when_then, Some(lit_bool(false))).unwrap()); + + and(and(is_null_expr, item_case_expr), date_case_expr) +} + +/// Measures how long `PhysicalExprSimplifier::simplify` takes for a given expression. +fn bench_simplify( + c: &mut Criterion, + name: &str, + schema: &Schema, + expr: &Arc<dyn PhysicalExpr>, +) { + let simplifier = PhysicalExprSimplifier::new(schema); + c.bench_function(name, |b| { + b.iter(|| black_box(simplifier.simplify(black_box(Arc::clone(expr))).unwrap())) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let cs_schema = catalog_sales_schema(); + let ws_schema = web_sales_schema(); + + for num_partitions in [16, 128] { + bench_simplify( + c, + &format!("tpc-ds/q76/cs/{num_partitions}"), + &cs_schema, + &catalog_sales_predicate(num_partitions), + ); + bench_simplify( + c, + &format!("tpc-ds/q76/ws/{num_partitions}"), + &ws_schema, + &web_sales_predicate(num_partitions), + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/simplifier/const_evaluator.rs b/datafusion/physical-expr/src/simplifier/const_evaluator.rs index 8a2368c40..1e62e47ce 100644 --- a/datafusion/physical-expr/src/simplifier/const_evaluator.rs +++ b/datafusion/physical-expr/src/simplifier/const_evaluator.rs @@ -40,17 +40,22 @@ use crate::expressions::{Column, Literal}; /// - `(1 + 2) * 3` -> `9` (with bottom-up traversal) /// - `'hello' || ' world'` -> `'hello world'` pub fn simplify_const_expr( - expr: &Arc<dyn PhysicalExpr>, + expr: Arc<dyn PhysicalExpr>, ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> { - if !can_evaluate_as_constant(expr) { - return Ok(Transformed::no(Arc::clone(expr))); - } + simplify_const_expr_with_dummy(expr, &create_dummy_batch()?) 
+} - // Create a 1-row dummy batch for evaluation - let batch = create_dummy_batch()?; +pub(crate) fn simplify_const_expr_with_dummy( + expr: Arc<dyn PhysicalExpr>, + batch: &RecordBatch, +) -> Result<Transformed<Arc<dyn PhysicalExpr>>> { + // If expr is already a const literal or can't be evaluated into one. + if expr.as_any().is::<Literal>() || (!can_evaluate_as_constant(&expr)) { + return Ok(Transformed::no(expr)); + } // Evaluate the expression - match expr.evaluate(&batch) { + match expr.evaluate(batch) { Ok(ColumnarValue::Scalar(scalar)) => { Ok(Transformed::yes(Arc::new(Literal::new(scalar)))) } @@ -61,13 +66,13 @@ pub fn simplify_const_expr( } Ok(_) => { // Unexpected result - keep original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } Err(_) => { // On error, keep original expression // The expression might succeed at runtime due to short-circuit evaluation // or other runtime conditions - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } } } @@ -95,7 +100,7 @@ fn can_evaluate_as_constant(expr: &Arc<dyn PhysicalExpr>) -> bool { /// that only contain literals, the batch content is irrelevant. /// /// This is the same approach used in the logical expression `ConstEvaluator`. 
-fn create_dummy_batch() -> Result<RecordBatch> { +pub(crate) fn create_dummy_batch() -> Result<RecordBatch> { // RecordBatch requires at least one column let dummy_schema = Arc::new(Schema::new(vec![Field::new("_", DataType::Null, true)])); let col = new_null_array(&DataType::Null, 1); diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs index 3bd4683c1..45ead82a0 100644 --- a/datafusion/physical-expr/src/simplifier/mod.rs +++ b/datafusion/physical-expr/src/simplifier/mod.rs @@ -21,7 +21,14 @@ use arrow::datatypes::Schema; use datafusion_common::{Result, tree_node::TreeNode}; use std::sync::Arc; -use crate::{PhysicalExpr, simplifier::not::simplify_not_expr}; +use crate::{ + PhysicalExpr, + simplifier::{ + const_evaluator::{create_dummy_batch, simplify_const_expr_with_dummy}, + not::simplify_not_expr, + unwrap_cast::unwrap_cast_in_comparison, + }, +}; pub mod const_evaluator; pub mod not; @@ -50,6 +57,8 @@ impl<'a> PhysicalExprSimplifier<'a> { let mut count = 0; let schema = self.schema; + let batch = create_dummy_batch()?; + while count < MAX_LOOP_COUNT { count += 1; let result = current_expr.transform(|node| { @@ -58,11 +67,11 @@ impl<'a> PhysicalExprSimplifier<'a> { // Apply NOT expression simplification first, then unwrap cast optimization, // then constant expression evaluation - let rewritten = simplify_not_expr(&node, schema)? + let rewritten = simplify_not_expr(node, schema)? + .transform_data(|node| unwrap_cast_in_comparison(node, schema))? .transform_data(|node| { - unwrap_cast::unwrap_cast_in_comparison(node, schema) - })? 
- .transform_data(|node| const_evaluator::simplify_const_expr(&node))?; + simplify_const_expr_with_dummy(node, &batch) + })?; #[cfg(debug_assertions)] assert_eq!( diff --git a/datafusion/physical-expr/src/simplifier/not.rs b/datafusion/physical-expr/src/simplifier/not.rs index 9b65d5cba..ea5467d0a 100644 --- a/datafusion/physical-expr/src/simplifier/not.rs +++ b/datafusion/physical-expr/src/simplifier/not.rs @@ -44,13 +44,13 @@ use crate::expressions::{BinaryExpr, InListExpr, Literal, NotExpr, in_list, lit} /// TreeNodeRewriter, multiple passes will automatically be applied until no more /// transformations are possible. pub fn simplify_not_expr( - expr: &Arc<dyn PhysicalExpr>, + expr: Arc<dyn PhysicalExpr>, schema: &Schema, ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> { // Check if this is a NOT expression let not_expr = match expr.as_any().downcast_ref::<NotExpr>() { Some(not_expr) => not_expr, - None => return Ok(Transformed::no(Arc::clone(expr))), + None => return Ok(Transformed::no(expr)), }; let inner_expr = not_expr.arg(); @@ -120,5 +120,5 @@ pub fn simplify_not_expr( } // If no simplification possible, return the original expression - Ok(Transformed::no(Arc::clone(expr))) + Ok(Transformed::no(expr)) } diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs index ae6da9c5e..0de517cd3 100644 --- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs +++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs @@ -34,10 +34,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{ - Result, ScalarValue, - tree_node::{Transformed, TreeNode}, -}; +use datafusion_common::{Result, ScalarValue, tree_node::Transformed}; use datafusion_expr::Operator; use datafusion_expr_common::casts::try_cast_literal_to_type; @@ -49,14 +46,12 @@ pub(crate) fn unwrap_cast_in_comparison( expr: Arc<dyn PhysicalExpr>, schema: &Schema, ) -> Result<Transformed<Arc<dyn 
PhysicalExpr>>> { - expr.transform_down(|e| { - if let Some(binary) = e.as_any().downcast_ref::<BinaryExpr>() - && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? - { - return Ok(Transformed::yes(unwrapped)); - } - Ok(Transformed::no(e)) - }) + if let Some(binary) = expr.as_any().downcast_ref::<BinaryExpr>() + && let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)? + { + return Ok(Transformed::yes(unwrapped)); + } + Ok(Transformed::no(expr)) } /// Try to unwrap casts in binary expressions @@ -144,7 +139,7 @@ mod tests { use super::*; use crate::expressions::{col, lit}; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::ScalarValue; + use datafusion_common::{ScalarValue, tree_node::TreeNode}; use datafusion_expr::Operator; /// Check if an expression is a cast expression @@ -484,8 +479,10 @@ mod tests { let and_expr = Arc::new(BinaryExpr::new(compare1, Operator::And, compare2)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc<dyn PhysicalExpr>) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); @@ -602,8 +599,10 @@ mod tests { // Create AND expression let and_expr = Arc::new(BinaryExpr::new(c1_binary, Operator::And, c2_binary)); - // Apply unwrap cast optimization - let result = unwrap_cast_in_comparison(and_expr, &schema).unwrap(); + // Apply unwrap cast optimization recursively + let result = (and_expr as Arc<dyn PhysicalExpr>) + .transform_down(|node| unwrap_cast_in_comparison(node, &schema)) + .unwrap(); // Should be transformed assert!(result.transformed); --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
