alamb commented on code in PR #18183: URL: https://github.com/apache/datafusion/pull/18183#discussion_r2596262315
########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs: ########## @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use crate::expressions::case::CaseBody; +use crate::expressions::Literal; +use arrow::array::{downcast_primitive, Array, ArrayRef, UInt32Array}; +use arrow::datatypes::DataType; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE <expr_a> +/// WHEN <literal_a> THEN <literal_e> +/// WHEN <literal_b> THEN <literal_f> +/// WHEN <literal_c> THEN <literal_g> +/// WHEN <literal_d> THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we 
should think of unwrapping the `IN` expressions into multiple equality comparisons Review Comment: I recommend filing a ticket to track this idea (and leaving a link to it here); otherwise, this TODO may never be found. ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs: ########## @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + downcast_integer, Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, + DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, + StringArray, StringViewArray, +}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, StringViewType, +}; +use datafusion_common::{internal_err, plan_datafusion_err, HashMap, ScalarValue}; +use std::fmt::Debug; + +/// Map from byte-like literal values to their first occurrence index +/// +/// This is a wrapper for handling different kinds of literal maps +#[derive(Clone, Debug)] +pub(super) struct BytesLikeIndexMap { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap<Vec<u8>, u32>, +} + +impl BytesLikeIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec<ScalarValue>, + ) -> datafusion_common::Result<Self> { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.logical_null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map: HashMap<Vec<u8>, u32> = try_get_bytes_iterator(&input)? 
+ // Flattening Option<&[u8]> to &[u8] as literals cannot contain nulls + .flatten() + .enumerate() + .map(|(map_index, value)| (value.to_vec(), map_index as u32)) + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .collect(); + + Ok(Self { map }) + } +} + +impl WhenLiteralIndexMap for BytesLikeIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result<Vec<u32>> { + let indices = try_get_bytes_iterator(array)? + .map(|value| match value { + Some(value) => self.map.get(value).copied().unwrap_or(else_index), + None => else_index, + }) + .collect::<Vec<u32>>(); + + Ok(indices) + } +} + +fn try_get_bytes_iterator( Review Comment: I recommend pulling this into a utility function (or a method on some other type) -- it might be helpful when implementing other string-like functions. Perhaps `datafusion/physical-expr/src/utils/bytes_iter.rs` or something similar. ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs: ########## @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray, PrimitiveArray, +}; +use arrow::datatypes::{i256, DataType, IntervalDayTime, IntervalMonthDayNano}; +use datafusion_common::{internal_err, HashMap, ScalarValue}; +use half::f16; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Clone)] +pub(super) struct PrimitiveIndexMap<T> +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + data_type: DataType, + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap<<T::Native as ToHashableKey>::HashableKey, u32>, +} + +impl<T> Debug for PrimitiveIndexMap<T> +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveIndexMap") + .field("map", &self.map) + .finish() + } +} + +impl<T> PrimitiveIndexMap<T> +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Try creating a new lookup table from the given literals and else index. + /// The index of each literal in the vector is used as the mapped value in the lookup table. 
+ /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec<ScalarValue>, + ) -> datafusion_common::Result<Self> { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map = input + .as_primitive::<T>() + .values() + .iter() + .enumerate() + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .map(|(map_index, value)| (value.into_hashable_key(), map_index as u32)) + .collect(); + + Ok(Self { + map, + data_type: input.data_type().clone(), + }) + } + + fn map_primitive_array_to_when_indices( + &self, + array: &PrimitiveArray<T>, + else_index: u32, + ) -> datafusion_common::Result<Vec<u32>> { + let indices = array + .into_iter() + .map(|value| match value { + Some(value) => self + .map + .get(&value.into_hashable_key()) + .copied() + .unwrap_or(else_index), + + None => else_index, + }) + .collect::<Vec<u32>>(); + + Ok(indices) + } +} + +impl<T> WhenLiteralIndexMap for PrimitiveIndexMap<T> +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result<Vec<u32>> { + match array.data_type() { + dt if dt == &self.data_type => { + let primitive_array = array.as_primitive::<T>(); + + self.map_primitive_array_to_when_indices(primitive_array, else_index) + } + // We support dictionary primitive array as we create the lookup table in `CaseWhen` expression + // creation when we don't know the schema, so we may receive dictionary encoded primitive arrays at execution time. + DataType::Dictionary(_, value_type) + if value_type.as_ref() == &self.data_type => + { + // Cast here to simplify the implementation. 
Review Comment: as above, you can probably reduce the size of this code by always calling `cast` ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs: ########## @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use crate::expressions::case::CaseBody; +use crate::expressions::Literal; +use arrow::array::{downcast_primitive, Array, ArrayRef, UInt32Array}; +use arrow::datatypes::DataType; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE <expr_a> +/// WHEN <literal_a> THEN <literal_e> +/// WHEN <literal_b> THEN <literal_f> +/// WHEN <literal_c> THEN <literal_g> +/// WHEN <literal_d> THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. 
+/// ```sql +/// -- Before +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> in (<literal_b>, <literal_c>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_g> +/// ELSE <optional-fallback_literal> +/// +/// -- After +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> = <literal_b>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_c>) THEN <literal_g> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Box<dyn WhenLiteralIndexMap>, + + else_index: u32, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + /// + /// This will be used to take from based on the indices returned by the lookup table to build the final output + then_and_else_values: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new(body: &CaseBody) -> Option<Self> { + // We can't use the optimization if we don't have any when then pairs + if body.when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if body.when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = body + .when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::<Literal>(); + let then_maybe_literal = then.as_any().downcast_ref::<Literal>(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::<Vec<_>>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have 
already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::<Vec<_>>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE) + let (when, then): (Vec<ScalarValue>, Vec<ScalarValue>) = { + let mut map = IndexMap::with_capacity(body.when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr { + let literal = else_expr.as_any().downcast_ref::<Literal>()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else { + return None; + }; + + null_scalar + }; + + { + let when_data_type = when[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization + if when.iter().any(|l| l.data_type() != when_data_type) { + return None; + } + } + + { + let data_type = then[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_value.data_type() != data_type { + return None; + } + } + + let then_and_else_values = ScalarValue::iter_to_array( + 
then.iter() + // The else is in the end + .chain(std::iter::once(&else_value)) + .cloned(), + ) + .ok()?; + // The else expression is in the end + let else_index = then_and_else_values.len() as u32 - 1; + + let lookup = try_creating_lookup_table(when).ok()?; + + Some(Self { + lookup, + then_and_else_values, + else_index, + }) + } + + pub(in super::super) fn map_keys_to_values( + &self, + keys_array: &ArrayRef, + ) -> datafusion_common::Result<ArrayRef> { + let take_indices = self + .lookup + .map_to_when_indices(keys_array, self.else_index)?; + + // Zero-copy conversion + let take_indices = UInt32Array::from(take_indices); + + // An optimize version would depend on the type of the values_to_take_from Review Comment: I believe the `take` kernel already implements the optimization this comment is referring to (this is one of the benefits of re-using the arrow kernels). Thus I think this comment no longer applies ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs: ########## @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + downcast_integer, Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, + DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, + StringArray, StringViewArray, +}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, StringViewType, +}; +use datafusion_common::{internal_err, plan_datafusion_err, HashMap, ScalarValue}; +use std::fmt::Debug; + +/// Map from byte-like literal values to their first occurrence index +/// +/// This is a wrapper for handling different kinds of literal maps +#[derive(Clone, Debug)] +pub(super) struct BytesLikeIndexMap { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap<Vec<u8>, u32>, +} + +impl BytesLikeIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec<ScalarValue>, + ) -> datafusion_common::Result<Self> { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.logical_null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map: HashMap<Vec<u8>, u32> = try_get_bytes_iterator(&input)? 
+ // Flattening Option<&[u8]> to &[u8] as literals cannot contain nulls + .flatten() + .enumerate() + .map(|(map_index, value)| (value.to_vec(), map_index as u32)) + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .collect(); + + Ok(Self { map }) + } +} + +impl WhenLiteralIndexMap for BytesLikeIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result<Vec<u32>> { + let indices = try_get_bytes_iterator(array)? + .map(|value| match value { + Some(value) => self.map.get(value).copied().unwrap_or(else_index), + None => else_index, + }) + .collect::<Vec<u32>>(); + + Ok(indices) + } +} + +fn try_get_bytes_iterator( + array: &ArrayRef, +) -> datafusion_common::Result<Box<dyn Iterator<Item = Option<&[u8]>> + '_>> { + Ok(match array.data_type() { + DataType::Utf8 => Box::new(array.as_string::<i32>().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })), + + DataType::LargeUtf8 => { + Box::new(array.as_string::<i64>().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } + + DataType::Binary => Box::new(array.as_binary::<i32>().into_iter()), + + DataType::LargeBinary => Box::new(array.as_binary::<i64>().into_iter()), + + DataType::FixedSizeBinary(_) => Box::new(array.as_binary::<i64>().into_iter()), + + DataType::Utf8View => Box::new( + array + .as_byte_view::<StringViewType>() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + DataType::BinaryView => { + Box::new(array.as_byte_view::<BinaryViewType>().into_iter()) + } + + DataType::Dictionary(key, _) => { + macro_rules! downcast_dictionary_array_helper { + ($t:ty) => {{ + get_bytes_iterator_for_dictionary(array.as_dictionary::<$t>())? + }}; + } + + downcast_integer! 
{ + key.as_ref() => (downcast_dictionary_array_helper), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + t => { + return Err(plan_datafusion_err!( + "Unsupported data type for bytes lookup table: {}", + t + )) + } + }) +} + +fn get_bytes_iterator_for_dictionary<K: ArrowDictionaryKeyType + Send + Sync>( + array: &DictionaryArray<K>, +) -> datafusion_common::Result<Box<dyn Iterator<Item = Option<&[u8]>> + '_>> { + Ok(match array.values().data_type() { + DataType::Utf8 => Box::new( + array Review Comment: Minor nit: this method is inconsistent with the `Utf8` branch above, which uses `as_string::<i32>()`. I recommend having this code follow the same pattern as above, as it is less verbose: ```rust DataType::Utf8 => Box::new(array.as_string::<i32>().into_iter().map(|item| { ``` This comment applies to the other types here too. ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs: ########## @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, ScalarValue}; + +#[derive(Clone, Debug)] +pub(super) struct BooleanIndexMap { + true_index: Option<u32>, + false_index: Option<u32>, +} + +impl BooleanIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec<ScalarValue>, + ) -> datafusion_common::Result<Self> { + let mut true_index: Option<u32> = None; + let mut false_index: Option<u32> = None; + + for (index, literal) in unique_non_null_literals.into_iter().enumerate() { + match literal { + ScalarValue::Boolean(Some(true)) => { + if true_index.is_some() { + return internal_err!( + "Duplicate true literal found in literals for BooleanIndexMap" + ); + } + true_index = Some(index as u32); + } + ScalarValue::Boolean(Some(false)) => { + if false_index.is_some() { + return internal_err!( + "Duplicate false literal found in literals for BooleanIndexMap" + ); + } + false_index = Some(index as u32); + } + ScalarValue::Boolean(None) => { + return internal_err!( + "Null literal found in non-null literals for BooleanIndexMap" + ) + } + _ => { + return internal_err!( + "Non-boolean literal found in literals for BooleanIndexMap" + ) + } + } + } + + Ok(Self { + true_index, + false_index, + }) + } + + fn map_boolean_array_to_when_indices( + &self, + array: &BooleanArray, + else_index: u32, + ) -> datafusion_common::Result<Vec<u32>> { + let true_index = self.true_index.unwrap_or(else_index); + let false_index = self.false_index.unwrap_or(else_index); + + Ok(array + .into_iter() + .map(|value| match value { + Some(true) => true_index, + Some(false) => false_index, + None => 
else_index, + }) + .collect::<Vec<u32>>()) + } +} + +impl WhenLiteralIndexMap for BooleanIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result<Vec<u32>> { + match array.data_type() { + DataType::Boolean => { + self.map_boolean_array_to_when_indices(array.as_boolean(), else_index) + } + // We support dictionary boolean array as we create the lookup table in `CaseWhen` expression + // creation when we don't know the schema, so we may receive dictionary encoded boolean arrays at execution time. + DataType::Dictionary(_, value_type) + if value_type.as_ref() == &DataType::Boolean => + { + // Since it is not common to have dictionary encoded boolean arrays + // at all than it is ok to do the cast here to simplify the implementation. + let converted = arrow::compute::cast(array.as_ref(), &DataType::Boolean)?; Review Comment: I think you could make this code simpler by always calling `cast` to `Boolean` (which is a no-op for an array that is already a `BooleanArray`). However, the current structure produces more specific errors, so I don't think any changes are required. ########## datafusion/physical-expr/src/expressions/case.rs: ########## @@ -73,8 +75,37 @@ enum EvalMethod { /// /// CASE WHEN condition THEN expression ELSE expression END ExpressionOrExpression(ProjectedCaseBody), + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable), +} + +/// Implementing hash so we can use `derive` on [`EvalMethod`]. +/// +/// not implementing actual [`Hash`] as it is not dyn compatible so we cannot implement it for +/// `dyn` [`literal_lookup_table::WhenLiteralIndexMap`].
+/// +/// So implementing empty hash is still valid as the data is derived from `PhysicalExpr` s which are already hashed +impl Hash for LiteralLookupTable { + fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {} +} + +/// Implementing Equal so we can use `derive` on [`EvalMethod`]. +/// +/// not implementing actual [`PartialEq`] as it is not dyn compatible so we cannot implement it for +/// `dyn` [`literal_lookup_table::WhenLiteralIndexMap`]. +/// +/// So we always return true as the data is derived from `PhysicalExpr` s which are already compared Review Comment: I think it is OK to return `true` here, but I don't quite follow this argument (i.e. I recommend updating this comment and the one for `Hash`). It seems to me it is OK to return `true` here because a `CaseExpr` has a `body` and an `eval` method (`EvalMethod`): ```shell https://github.com/apache/datafusion/blob/0cc93a6a750ef3082e8fa94c8d8248d980e49e5f/datafusion/physical-expr/src/expressions/case.rs#L249-L254 ``` Since the `eval` method is determined deterministically and solely by the `body`, if the bodies are equal then `eval` is guaranteed to be the same as well, so no further checks are required. ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs: ########## @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use crate::expressions::case::CaseBody; +use crate::expressions::Literal; +use arrow::array::{downcast_primitive, Array, ArrayRef, UInt32Array}; +use arrow::datatypes::DataType; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE <expr_a> +/// WHEN <literal_a> THEN <literal_e> +/// WHEN <literal_b> THEN <literal_f> +/// WHEN <literal_c> THEN <literal_g> +/// WHEN <literal_d> THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. 
+/// ```sql +/// -- Before +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> in (<literal_b>, <literal_c>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_g> +/// ELSE <optional-fallback_literal> +/// +/// -- After +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> = <literal_b>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_c>) THEN <literal_g> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Box<dyn WhenLiteralIndexMap>, + + else_index: u32, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + /// + /// This will be used to take from based on the indices returned by the lookup table to build the final output + then_and_else_values: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new(body: &CaseBody) -> Option<Self> { + // We can't use the optimization if we don't have any when then pairs + if body.when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if body.when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = body + .when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::<Literal>(); + let then_maybe_literal = then.as_any().downcast_ref::<Literal>(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::<Vec<_>>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have 
already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::<Vec<_>>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE) + let (when, then): (Vec<ScalarValue>, Vec<ScalarValue>) = { + let mut map = IndexMap::with_capacity(body.when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr { + let literal = else_expr.as_any().downcast_ref::<Literal>()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else { + return None; + }; + + null_scalar + }; + + { + let when_data_type = when[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization Review Comment: By the time we get here I would expect that all when literals are the same type (so I would expect this code not to be called) ########## datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs: ########## @@ -0,0 +1,327 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use crate::expressions::case::CaseBody; +use crate::expressions::Literal; +use arrow::array::{downcast_primitive, Array, ArrayRef, UInt32Array}; +use arrow::datatypes::DataType; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE <expr_a> +/// WHEN <literal_a> THEN <literal_e> +/// WHEN <literal_b> THEN <literal_f> +/// WHEN <literal_c> THEN <literal_g> +/// WHEN <literal_d> THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. 
+/// ```sql +/// -- Before +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> in (<literal_b>, <literal_c>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_g> +/// ELSE <optional-fallback_literal> +/// +/// -- After +/// CASE +/// WHEN (<expr_a> = <literal_a>) THEN <literal_e> +/// WHEN (<expr_a> = <literal_b>) THEN <literal_f> +/// WHEN (<expr_a> = <literal_c>) THEN <literal_g> +/// WHEN (<expr_a> = <literal_d>) THEN <literal_h> +/// ELSE <optional-fallback_literal> +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Box<dyn WhenLiteralIndexMap>, + + else_index: u32, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + /// + /// This will be used to take from based on the indices returned by the lookup table to build the final output + then_and_else_values: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new(body: &CaseBody) -> Option<Self> { + // We can't use the optimization if we don't have any when then pairs + if body.when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if body.when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = body + .when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::<Literal>(); + let then_maybe_literal = then.as_any().downcast_ref::<Literal>(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::<Vec<_>>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have 
already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::<Vec<_>>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE) + let (when, then): (Vec<ScalarValue>, Vec<ScalarValue>) = { + let mut map = IndexMap::with_capacity(body.when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr { + let literal = else_expr.as_any().downcast_ref::<Literal>()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else { + return None; + }; + + null_scalar + }; + + { + let when_data_type = when[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization + if when.iter().any(|l| l.data_type() != when_data_type) { + return None; + } + } + + { + let data_type = then[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_value.data_type() != data_type { + return None; + } + } + + let then_and_else_values = ScalarValue::iter_to_array( + 
then.iter() + // The else is in the end + .chain(std::iter::once(&else_value)) + .cloned(), + ) + .ok()?; Review Comment: Why (silently) ignore any errors? (I don't expect that this would ever fail given the checks above, but it might help debugging to explicitly return an internal error or something in this case.) Silently ignoring the error (via .ok()) seems like it could mask errors that are preventing this optimization from actually happening. Also, we had some problems in the past with performance when using `ok()` because each DataFusionError requires a string allocation, so creating one and ignoring it causes non-trivial overhead. Since this code is called once per expression, I suspect it isn't a big performance problem, but I do think the masking of errors is concerning -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
