alamb commented on code in PR #18152:
URL: https://github.com/apache/datafusion/pull/18152#discussion_r2462721912


##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,276 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    // Always optimize the filter since we use them multiple times.
+    filter_builder = filter_builder.optimize();
+    filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    // SAFETY: since we start from a valid RecordBatch, there's no need to 
revalidate the schema
+    // since the set of columns has not changed.
+    // The input column arrays all had the same length (since they're coming 
from a valid RecordBatch)
+    // and the filtering them with the same filter will produces a new set of 
arrays with identical
+    // lengths.
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(

Review Comment:
   A minor nit: why bother with a function here (why not call `filter.filter`
directly)?



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,276 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    // Always optimize the filter since we use them multiple times.
+    filter_builder = filter_builder.optimize();
+    filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    // SAFETY: since we start from a valid RecordBatch, there's no need to 
revalidate the schema
+    // since the set of columns has not changed.
+    // The input column arrays all had the same length (since they're coming 
from a valid RecordBatch)
+    // and the filtering them with the same filter will produces a new set of 
arrays with identical
+    // lengths.
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    filter.filter(array)
+}
+
+///
+/// Merges elements by index from a list of [`ArrayData`], creating a new 
[`ColumnarValue`] from
+/// those values.
+///
+/// Each element in `indices` is the index of an array in `values` offset by 
1. `indices` is
+/// processed sequentially. The first occurrence of index value `n` will be 
mapped to the first
+/// value of array `n - 1`. The second occurrence to the second value, and so 
on.
+///
+/// The index value `0` is used to indicate null values.
+///
+/// ```text
+/// ┌─────────────────┐      ┌─────────┐                                  
┌─────────────────┐
+/// │        A        │      │    0    │        merge(                    │    
   NULL      │
+/// ├─────────────────┤      ├─────────┤          [values0, values1],     
├─────────────────┤
+/// │        D        │      │    2    │          indices                 │    
    B        │
+/// └─────────────────┘      ├─────────┤        )                         
├─────────────────┤
+///   values array 0         │    2    │      ─────────────────────────▶  │    
    C        │
+///                          ├─────────┤                                  
├─────────────────┤
+///                          │    1    │                                  │    
    A        │
+///                          ├─────────┤                                  
├─────────────────┤
+///                          │    1    │                                  │    
    D        │
+/// ┌─────────────────┐      ├─────────┤                                  
├─────────────────┤
+/// │        B        │      │    2    │                                  │    
    E        │
+/// ├─────────────────┤      └─────────┘                                  
└─────────────────┘
+/// │        C        │
+/// ├─────────────────┤        indices
+/// │        E        │         array                                      
result
+/// └─────────────────┘
+///   values array 1
+/// ```
+fn merge(values: &[ArrayData], indices: &[usize]) -> Result<ArrayRef> {
+    let data_refs = values.iter().collect();
+    let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
+
+    // This loop extends the mutable array by taking slices from the partial 
results.
+    //
+    // take_offsets keeps track of how many values have been taken from each 
array.
+    let mut take_offsets = vec![0; values.len() + 1];
+    let mut start_row_ix = 0;
+    loop {
+        let array_ix = indices[start_row_ix];
+
+        // Determine the length of the slice to take.
+        let mut end_row_ix = start_row_ix + 1;
+        while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
+            end_row_ix += 1;
+        }
+        let slice_length = end_row_ix - start_row_ix;
+
+        // Extend mutable with either nulls or with values from the array.
+        let start_offset = take_offsets[array_ix];
+        let end_offset = start_offset + slice_length;
+        if array_ix == 0 {
+            mutable.extend_nulls(slice_length);
+        } else {
+            mutable.extend(array_ix - 1, start_offset, end_offset);
+        }
+
+        if end_row_ix == indices.len() {
+            break;
+        } else {
+            // Update the take_offsets array.
+            take_offsets[array_ix] = end_offset;
+            // Set the start_row_ix for the next slice.
+            start_row_ix = end_row_ix;
+        }
+    }
+
+    Ok(make_array(mutable.freeze()))
+}
+
+/// A builder for constructing result arrays for CASE expressions.
+///
+/// Rather than building a monolithic array containing all results, it 
maintains a set of
+/// partial result arrays and a mapping that indicates for each row which 
partial array
+/// contains the result value for that row.
+///
+/// On finish(), the builder will merge all partial results into a single 
array if necessary.
+/// If all rows evaluated to the same array, that array can be returned 
directly without
+/// any merging overhead.
+struct ResultBuilder {
+    data_type: DataType,
+    // A Vec of partial results that should be merged. 
`partial_result_indices` contains
+    // indexes into this vec.
+    partial_results: Vec<ArrayData>,
+    // Indicates per result row from which array in `partial_results` a value 
should be taken.
+    // The indexes in this array are offset by +1. The special value 0 
indicates null values.
+    partial_result_indices: Vec<usize>,
+    // An optional result that is the covering result for all rows.
+    // This is used as an optimisation to avoid the cost of merging when all 
rows
+    // evaluate to the same case branch.
+    covering_result: Option<ColumnarValue>,
+}
+
+impl ResultBuilder {
+    /// Creates a new ResultBuilder that will produce arrays of the given data 
type.
+    ///
+    /// The capacity parameter indicates the number of rows in the result.
+    fn new(data_type: &DataType, capacity: usize) -> Self {
+        Self {
+            data_type: data_type.clone(),
+            partial_result_indices: vec![0; capacity],
+            partial_results: vec![],
+            covering_result: None,
+        }
+    }
+
+    /// Adds a result for one branch of the case expression.
+    ///
+    /// `row_indices` should be a [UInt32Array] containing [RecordBatch] 
relative row indices
+    /// for which `value` contains result values.
+    ///
+    /// If `value` is a scalar, the scalar value will be used as the value for 
each row in `row_indices`.
+    ///
+    /// If `value` is an array, the values from the array and the indices from 
`row_indices` will be
+    /// processed pairwise. The lengths of `value` and `row_indices` must 
match.
+    ///
+    /// The diagram below shows a situation where a when expression matched 
rows 1 and 4 of the
+    /// record batch. The then expression produced the value array `[A, D]`.
+    /// After adding this result, the result array will have been added to 
`partial_results` and
+    /// `partial_indices` will have been updated at indexes 1 and 4.
+    ///
+    /// ```text
+    /// ┌─────────┐     ┌─────────┐┌───────────┐                            
┌─────────┐┌───────────┐
+    /// │    A    │     │    0    ││           │                            │  
  0    ││┌─────────┐│
+    /// ├─────────┤     ├─────────┤│           │                            
├─────────┤││    A    ││
+    /// │    D    │     │    0    ││           │                            │  
  1    ││├─────────┤│
+    /// └─────────┘     ├─────────┤│           │   add_branch_result(       
├─────────┤││    D    ││
+    ///   value         │    0    ││           │     row indices,           │  
  0    ││└─────────┘│
+    ///                 ├─────────┤│           │     value                  
├─────────┤│           │
+    ///                 │    0    ││           │   )                        │  
  0    ││           │
+    /// ┌─────────┐     ├─────────┤│           │ ─────────────────────────▶ 
├─────────┤│           │
+    /// │    1    │     │    0    ││           │                            │  
  1    ││           │
+    /// ├─────────┤     ├─────────┤│           │                            
├─────────┤│           │
+    /// │    4    │     │    0    ││           │                            │  
  0    ││           │
+    /// └─────────┘     └─────────┘└───────────┘                            
└─────────┘└───────────┘
+    /// row indices
+    ///                   partial     partial                                 
partial     partial
+    ///                   indices     results                                 
indices     results
+    /// ```
+    fn add_branch_result(
+        &mut self,
+        row_indices: &ArrayRef,
+        value: ColumnarValue,
+    ) -> Result<()> {
+        match value {
+            ColumnarValue::Array(a) => {
+                assert_eq!(a.len(), row_indices.len());
+                if row_indices.len() == self.partial_result_indices.len() {

Review Comment:
   I expected this to also have to check the values in `row_indices` to make sure
all rows are covered. Does this assume that there are no duplicates in
`row_indices`?



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,276 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    // Always optimize the filter since we use them multiple times.
+    filter_builder = filter_builder.optimize();
+    filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    // SAFETY: since we start from a valid RecordBatch, there's no need to 
revalidate the schema
+    // since the set of columns has not changed.
+    // The input column arrays all had the same length (since they're coming 
from a valid RecordBatch)
+    // and the filtering them with the same filter will produces a new set of 
arrays with identical
+    // lengths.
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    filter.filter(array)
+}
+
+///

Review Comment:
   I recommend also adding a comment here explaining that this is a specialized
version of the upstream `interleave` kernel just for case expressions, and that
we aren't using `interleave` because this specialized version is faster.



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,276 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    // Always optimize the filter since we use them multiple times.
+    filter_builder = filter_builder.optimize();
+    filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    // SAFETY: since we start from a valid RecordBatch, there's no need to 
revalidate the schema
+    // since the set of columns has not changed.
+    // The input column arrays all had the same length (since they're coming 
from a valid RecordBatch)
+    // and the filtering them with the same filter will produces a new set of 
arrays with identical
+    // lengths.
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    filter.filter(array)
+}
+
+///
+/// Merges elements by index from a list of [`ArrayData`], creating a new 
[`ColumnarValue`] from
+/// those values.
+///
+/// Each element in `indices` is the index of an array in `values` offset by 
1. `indices` is
+/// processed sequentially. The first occurrence of index value `n` will be 
mapped to the first
+/// value of array `n - 1`. The second occurrence to the second value, and so 
on.
+///
+/// The index value `0` is used to indicate null values.
+///
+/// ```text
+/// ┌─────────────────┐      ┌─────────┐                                  
┌─────────────────┐
+/// │        A        │      │    0    │        merge(                    │    
   NULL      │
+/// ├─────────────────┤      ├─────────┤          [values0, values1],     
├─────────────────┤
+/// │        D        │      │    2    │          indices                 │    
    B        │
+/// └─────────────────┘      ├─────────┤        )                         
├─────────────────┤
+///   values array 0         │    2    │      ─────────────────────────▶  │    
    C        │
+///                          ├─────────┤                                  
├─────────────────┤
+///                          │    1    │                                  │    
    A        │
+///                          ├─────────┤                                  
├─────────────────┤
+///                          │    1    │                                  │    
    D        │
+/// ┌─────────────────┐      ├─────────┤                                  
├─────────────────┤
+/// │        B        │      │    2    │                                  │    
    E        │
+/// ├─────────────────┤      └─────────┘                                  
└─────────────────┘
+/// │        C        │
+/// ├─────────────────┤        indices
+/// │        E        │         array                                      
result
+/// └─────────────────┘
+///   values array 1
+/// ```
+fn merge(values: &[ArrayData], indices: &[usize]) -> Result<ArrayRef> {

Review Comment:
   This seems very similar to the zip kernel: 
https://docs.rs/arrow/latest/arrow/compute/kernels/zip/fn.zip.html
   
   @rluvaton is also in the process of optimizing that kernel:
   - https://github.com/apache/arrow-rs/pull/8653
   
   Thus, if this code were to use `zip`, it could take advantage of those
optimizations when they land.
   
   
   Edit: I was confused by the example that only shows values 0 and 1 -- I see
now that this works for an arbitrary number of input arrays (more than 2).



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -196,82 +467,135 @@ impl CaseExpr {
     /// END
     fn case_when_with_expr(&self, batch: &RecordBatch) -> 
Result<ColumnarValue> {
         let return_type = self.data_type(&batch.schema())?;
-        let expr = self.expr.as_ref().unwrap();
-        let base_value = expr.evaluate(batch)?;
-        let base_value = base_value.into_array(batch.num_rows())?;
+        let mut result_builder = ResultBuilder::new(&return_type, 
batch.num_rows());
+
+        // `remainder_rows` contains the indices of the rows that need to be 
evaluated
+        let mut remainder_rows: ArrayRef =
+            Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as 
u32));
+        // `remainder_batch` contains the rows themselves that need to be 
evaluated
+        let mut remainder_batch = Cow::Borrowed(batch);
+
+        // evaluate the base expression
+        let mut base_value = self
+            .expr
+            .as_ref()
+            .unwrap()
+            .evaluate(batch)?
+            .into_array(batch.num_rows())?;

Review Comment:
   We can probably make this even faster by avoiding this `into_array` call and
handling the `ColumnarValue` variants specially (e.g. not expanding a scalar base
value into a full array).



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,276 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    // Always optimize the filter since we use them multiple times.
+    filter_builder = filter_builder.optimize();
+    filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    // SAFETY: since we start from a valid RecordBatch, there's no need to 
revalidate the schema
+    // since the set of columns has not changed.
+    // The input column arrays all had the same length (since they're coming 
from a valid RecordBatch)
+    // and the filtering them with the same filter will produces a new set of 
arrays with identical
+    // lengths.
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    filter.filter(array)
+}
+
+///
+/// Merges elements by index from a list of [`ArrayData`], creating a new 
[`ColumnarValue`] from
+/// those values.
+///
+/// Each element in `indices` is the index of an array in `values` offset by 
1. `indices` is
+/// processed sequentially. The first occurrence of index value `n` will be 
mapped to the first
+/// value of array `n - 1`. The second occurrence to the second value, and so 
on.
+///
+/// The index value `0` is used to indicate null values.

Review Comment:
   If you used `usize::MAX` for the null value, it might make the code easier to
read, as there wouldn't be as many special cases.



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,181 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    if optimize {
+        filter_builder = filter_builder.optimize();
+    }
+    filter_builder.build()
+}
+
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    filter.filter(array)
+}
+
+struct ResultBuilder {
+    data_type: DataType,
+    // A Vec of partial results that should be merged. 
`partial_result_indices` contains
+    // indexes into this vec.
+    partial_results: Vec<ArrayData>,
+    // Indicates per result row from which array in `partial_results` a value 
should be taken.
+    // The indexes in this array are offset by +1. The special value 0 
indicates null values.
+    partial_result_indices: Vec<usize>,
+    // An optional result that is the covering result for all rows.
+    // This is used as an optimisation to avoid the cost of merging when all 
rows
+    // evaluate to the same case branch.
+    covering_result: Option<ColumnarValue>,

Review Comment:
   Maybe "single value" would be a clearer name than "covering result"?
   
   Another way to make it clearer might be to use an enum like
   
   ```rust
   enum ResultState {
       /// A subset of the result should be returned
       Partial {
         // A Vec of partial results that should be merged. 
`partial_result_indices` contains
         // indexes into this vec.
         results: Vec<ArrayData>,
         // Indicates per result row from which array in `partial_results` a 
value should be taken.
         // The indexes in this array are offset by +1. The special value 0 
indicates null values.
         indices: Vec<usize>
       },
       // An optional result that is the covering result for all rows.
       // This is used as an optimisation to avoid the cost of merging when all 
rows
       // evaluate to the same case branch.
       Covering(ColumnarValue)
   }
   
   struct ResultBuilder { 
   ...
      /// The inprogress result
      result:  ResultState,
   ...
   }
   ```
   
   That way the compiler can verify that partial and covering results are never
mixed, rather than relying on runtime asserts.



##########
datafusion/physical-expr/src/expressions/case.rs:
##########
@@ -122,6 +123,276 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) 
-> bool {
     expr.as_any().is::<Column>()
 }
 
+/// Creates a [FilterPredicate] from a boolean array.
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    // Always optimize the filter since we use them multiple times.
+    filter_builder = filter_builder.optimize();
+    filter_builder.build()
+}
+
+// This should be removed when https://github.com/apache/arrow-rs/pull/8693
+// is merged and becomes available.
+fn filter_record_batch(
+    record_batch: &RecordBatch,
+    filter: &FilterPredicate,
+) -> std::result::Result<RecordBatch, ArrowError> {
+    let filtered_columns = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, filter))
+        .collect::<std::result::Result<Vec<_>, _>>()?;
+    // SAFETY: since we start from a valid RecordBatch, there's no need to 
revalidate the schema
+    // since the set of columns has not changed.
+    // The input column arrays all had the same length (since they're coming 
from a valid RecordBatch)
+    // and the filtering them with the same filter will produces a new set of 
arrays with identical
+    // lengths.
+    unsafe {
+        Ok(RecordBatch::new_unchecked(
+            record_batch.schema(),
+            filtered_columns,
+            filter.count(),
+        ))
+    }
+}
+
+#[inline(always)]
+fn filter_array(
+    array: &dyn Array,
+    filter: &FilterPredicate,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    filter.filter(array)
+}
+
+///
+/// Merges elements by index from a list of [`ArrayData`], creating a new 
[`ColumnarValue`] from
+/// those values.
+///
+/// Each element in `indices` is the index of an array in `values` offset by 
1. `indices` is
+/// processed sequentially. The first occurrence of index value `n` will be 
mapped to the first
+/// value of array `n - 1`. The second occurrence to the second value, and so 
on.
+///
+/// The index value `0` is used to indicate null values.
+///
+/// ```text
+/// ┌─────────────────┐      ┌─────────┐                                  
┌─────────────────┐
+/// │        A        │      │    0    │        merge(                    │    
   NULL      │
+/// ├─────────────────┤      ├─────────┤          [values0, values1],     
├─────────────────┤
+/// │        D        │      │    2    │          indices                 │    
    B        │
+/// └─────────────────┘      ├─────────┤        )                         
├─────────────────┤
+///   values array 0         │    2    │      ─────────────────────────▶  │    
    C        │
+///                          ├─────────┤                                  
├─────────────────┤
+///                          │    1    │                                  │    
    A        │
+///                          ├─────────┤                                  
├─────────────────┤
+///                          │    1    │                                  │    
    D        │
+/// ┌─────────────────┐      ├─────────┤                                  
├─────────────────┤
+/// │        B        │      │    2    │                                  │    
    E        │
+/// ├─────────────────┤      └─────────┘                                  
└─────────────────┘
+/// │        C        │
+/// ├─────────────────┤        indices
+/// │        E        │         array                                      
result
+/// └─────────────────┘
+///   values array 1
+/// ```
+fn merge(values: &[ArrayData], indices: &[usize]) -> Result<ArrayRef> {
+    let data_refs = values.iter().collect();
+    let mut mutable = MutableArrayData::new(data_refs, true, indices.len());
+
+    // This loop extends the mutable array by taking slices from the partial 
results.
+    //
+    // take_offsets keeps track of how many values have been taken from each 
array.
+    let mut take_offsets = vec![0; values.len() + 1];
+    let mut start_row_ix = 0;
+    loop {
+        let array_ix = indices[start_row_ix];
+
+        // Determine the length of the slice to take.
+        let mut end_row_ix = start_row_ix + 1;
+        while end_row_ix < indices.len() && indices[end_row_ix] == array_ix {
+            end_row_ix += 1;
+        }
+        let slice_length = end_row_ix - start_row_ix;
+
+        // Extend mutable with either nulls or with values from the array.
+        let start_offset = take_offsets[array_ix];
+        let end_offset = start_offset + slice_length;
+        if array_ix == 0 {
+            mutable.extend_nulls(slice_length);
+        } else {
+            mutable.extend(array_ix - 1, start_offset, end_offset);
+        }
+
+        if end_row_ix == indices.len() {
+            break;
+        } else {
+            // Update the take_offsets array.
+            take_offsets[array_ix] = end_offset;
+            // Set the start_row_ix for the next slice.
+            start_row_ix = end_row_ix;
+        }
+    }
+
+    Ok(make_array(mutable.freeze()))
+}
+
+/// A builder for constructing result arrays for CASE expressions.
+///
+/// Rather than building a monolithic array containing all results, it 
maintains a set of
+/// partial result arrays and a mapping that indicates for each row which 
partial array
+/// contains the result value for that row.
+///
+/// On finish(), the builder will merge all partial results into a single 
array if necessary.
+/// If all rows evaluated to the same array, that array can be returned 
directly without
+/// any merging overhead.
+struct ResultBuilder {
+    data_type: DataType,
+    // A Vec of partial results that should be merged. 
`partial_result_indices` contains
+    // indexes into this vec.
+    partial_results: Vec<ArrayData>,
+    // Indicates per result row from which array in `partial_results` a value 
should be taken.
+    // The indexes in this array are offset by +1. The special value 0 
indicates null values.
+    partial_result_indices: Vec<usize>,
+    // An optional result that is the covering result for all rows.
+    // This is used as an optimisation to avoid the cost of merging when all 
rows
+    // evaluate to the same case branch.
+    covering_result: Option<ColumnarValue>,
+}
+
+impl ResultBuilder {
+    /// Creates a new ResultBuilder that will produce arrays of the given data 
type.
+    ///
+    /// The capacity parameter indicates the number of rows in the result.
+    fn new(data_type: &DataType, capacity: usize) -> Self {
+        Self {
+            data_type: data_type.clone(),
+            partial_result_indices: vec![0; capacity],
+            partial_results: vec![],
+            covering_result: None,
+        }
+    }
+
+    /// Adds a result for one branch of the case expression.
+    ///
+    /// `row_indices` should be a [UInt32Array] containing [RecordBatch] 
relative row indices
+    /// for which `value` contains result values.
+    ///
+    /// If `value` is a scalar, the scalar value will be used as the value for 
each row in `row_indices`.
+    ///
+    /// If `value` is an array, the values from the array and the indices from 
`row_indices` will be
+    /// processed pairwise. The lengths of `value` and `row_indices` must 
match.
+    ///
+    /// The diagram below shows a situation where a when expression matched 
rows 1 and 4 of the
+    /// record batch. The then expression produced the value array `[A, D]`.
+    /// After adding this result, the result array will have been added to 
`partial_results` and
+    /// `partial_indices` will have been updated at indexes 1 and 4.
+    ///
+    /// ```text
+    /// ┌─────────┐     ┌─────────┐┌───────────┐                            
┌─────────┐┌───────────┐
+    /// │    A    │     │    0    ││           │                            │  
  0    ││┌─────────┐│
+    /// ├─────────┤     ├─────────┤│           │                            
├─────────┤││    A    ││
+    /// │    D    │     │    0    ││           │                            │  
  1    ││├─────────┤│
+    /// └─────────┘     ├─────────┤│           │   add_branch_result(       
├─────────┤││    D    ││
+    ///   value         │    0    ││           │     row indices,           │  
  0    ││└─────────┘│
+    ///                 ├─────────┤│           │     value                  
├─────────┤│           │
+    ///                 │    0    ││           │   )                        │  
  0    ││           │
+    /// ┌─────────┐     ├─────────┤│           │ ─────────────────────────▶ 
├─────────┤│           │
+    /// │    1    │     │    0    ││           │                            │  
  1    ││           │
+    /// ├─────────┤     ├─────────┤│           │                            
├─────────┤│           │
+    /// │    4    │     │    0    ││           │                            │  
  0    ││           │
+    /// └─────────┘     └─────────┘└───────────┘                            
└─────────┘└───────────┘
+    /// row indices
+    ///                   partial     partial                                 
partial     partial
+    ///                   indices     results                                 
indices     results
+    /// ```
+    fn add_branch_result(
+        &mut self,
+        row_indices: &ArrayRef,
+        value: ColumnarValue,
+    ) -> Result<()> {
+        match value {
+            ColumnarValue::Array(a) => {
+                assert_eq!(a.len(), row_indices.len());
+                if row_indices.len() == self.partial_result_indices.len() {
+                    self.set_covering_result(ColumnarValue::Array(a));
+                } else {
+                    self.add_partial_result(row_indices, a.to_data());
+                }
+            }
+            ColumnarValue::Scalar(s) => {
+                if row_indices.len() == self.partial_result_indices.len() {
+                    self.set_covering_result(ColumnarValue::Scalar(s));
+                } else {
+                    self.add_partial_result(
+                        row_indices,
+                        s.to_array_of_size(row_indices.len())?.to_data(),
+                    );
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Adds a partial result array.
+    ///
+    /// This method adds the given array data as a partial result and updates 
the index mapping
+    /// to indicate that the specified rows should take their values from this 
array.
+    /// The partial results will be merged into a single array when finish() 
is called.
+    fn add_partial_result(&mut self, row_indices: &ArrayRef, row_values: 
ArrayData) {
+        // Covering results and partial results are mutually exclusive.
+        // We can assert this since the case evaluation methods are written to 
only evaluate
+        // each row of the record batch once.
+        assert!(self.covering_result.is_none());
+
+        self.partial_results.push(row_values);
+        let array_index = self.partial_results.len();
+
+        for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() 
{
+            self.partial_result_indices[*row_ix as usize] = array_index;
+        }
+    }
+
+    /// Sets a covering result that applies to all rows.
+    ///
+    /// This is an optimization for cases where all rows evaluate to the same 
result.
+    /// When a covering result is set, the builder will return it directly 
from finish()
+    /// without any merging overhead.
+    fn set_covering_result(&mut self, value: ColumnarValue) {
+        // Covering results and partial results are mutually exclusive.
+        // We can assert this since the case evaluation methods are written to 
only evaluate
+        // each row of the record batch once.
+        assert!(self.partial_results.is_empty());

Review Comment:
   If you used the enum above, you could avoid this assert (the compiler would
ensure it for you).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to