andygrove commented on code in PR #3611:
URL: https://github.com/apache/datafusion-comet/pull/3611#discussion_r3057984890


##########
native/spark-expr/src/array_funcs/array_exists.rs:
##########
@@ -0,0 +1,546 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, ArrayRef, BooleanArray, LargeListArray, ListArray};
+use arrow::buffer::NullBuffer;
+use arrow::compute::kernels::take::take;
+use arrow::datatypes::{DataType, Field, Schema, UInt32Type};
+use arrow::record_batch::RecordBatch;
+use datafusion::common::{DataFusionError, Result as DataFusionResult};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+use std::any::Any;
+use std::fmt::{Debug, Display, Formatter};
+use std::hash::Hash;
+use std::sync::Arc;
+
+const LAMBDA_VAR_COLUMN: &str = "__comet_lambda_var";
+
+/// Decomposed list array: offsets as usize, values, and optional null buffer.
+struct ListComponents {
+    offsets: Vec<usize>,
+    values: ArrayRef,
+    nulls: Option<NullBuffer>,
+}
+
+impl ListComponents {
+    fn is_null(&self, row: usize) -> bool {
+        self.nulls.as_ref().is_some_and(|n| n.is_null(row))
+    }
+}
+
+fn decompose_list(array: &dyn Array) -> DataFusionResult<ListComponents> {
+    if let Some(list) = array.as_any().downcast_ref::<ListArray>() {
+        Ok(ListComponents {
+            offsets: list.offsets().iter().map(|&o| o as usize).collect(),
+            values: Arc::clone(list.values()),
+            nulls: list.nulls().cloned(),
+        })
+    } else if let Some(large) = 
array.as_any().downcast_ref::<LargeListArray>() {
+        Ok(ListComponents {
+            offsets: large.offsets().iter().map(|&o| o as usize).collect(),
+            values: Arc::clone(large.values()),
+            nulls: large.nulls().cloned(),
+        })
+    } else {
+        Err(DataFusionError::Internal(
+            "ArrayExists expects a ListArray or LargeListArray 
input".to_string(),
+        ))
+    }
+}
+
+/// Spark-compatible `array_exists(array, x -> predicate(x))`.
+///
+/// Evaluates the lambda body vectorized over all elements in a single pass 
rather
+/// than per-element to avoid repeated batch construction overhead.
+#[derive(Debug, Eq)]
+pub struct ArrayExistsExpr {
+    array_expr: Arc<dyn PhysicalExpr>,
+    lambda_body: Arc<dyn PhysicalExpr>,
+    follow_three_valued_logic: bool,
+}
+
+impl Hash for ArrayExistsExpr {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.array_expr.hash(state);
+        self.lambda_body.hash(state);
+        self.follow_three_valued_logic.hash(state);
+    }
+}
+
+impl PartialEq for ArrayExistsExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.array_expr.eq(&other.array_expr)
+            && self.lambda_body.eq(&other.lambda_body)
+            && self
+                .follow_three_valued_logic
+                .eq(&other.follow_three_valued_logic)
+    }
+}
+
+impl ArrayExistsExpr {
+    pub fn new(
+        array_expr: Arc<dyn PhysicalExpr>,
+        lambda_body: Arc<dyn PhysicalExpr>,
+        follow_three_valued_logic: bool,
+    ) -> Self {
+        Self {
+            array_expr,
+            lambda_body,
+            follow_three_valued_logic,
+        }
+    }
+}
+
+impl PhysicalExpr for ArrayExistsExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> DataFusionResult<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> DataFusionResult<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> 
{
+        let num_rows = batch.num_rows();
+
+        let array_value = 
self.array_expr.evaluate(batch)?.into_array(num_rows)?;
+        let list = decompose_list(array_value.as_ref())?;
+        let total_elements = list.values.len();
+
+        if total_elements == 0 {
+            let mut result_builder = BooleanArray::builder(num_rows);
+            for row in 0..num_rows {
+                if list.is_null(row) {
+                    result_builder.append_null();
+                } else {
+                    result_builder.append_value(false);
+                }
+            }
+            return Ok(ColumnarValue::Array(Arc::new(result_builder.finish())));
+        }
+
+        let mut repeat_indices = Vec::with_capacity(total_elements);
+        for row in 0..num_rows {
+            let start = list.offsets[row];
+            let end = list.offsets[row + 1];
+            for _ in start..end {
+                repeat_indices.push(row as u32);
+            }
+        }
+
+        let repeat_indices_array = 
arrow::array::PrimitiveArray::<UInt32Type>::from(repeat_indices);
+
+        let mut expanded_columns: Vec<ArrayRef> = 
Vec::with_capacity(batch.num_columns() + 1);
+        let mut expanded_fields: Vec<Arc<Field>> = 
Vec::with_capacity(batch.num_columns() + 1);
+
+        for (i, col) in batch.columns().iter().enumerate() {
+            let expanded = take(col.as_ref(), &repeat_indices_array, None)?;
+            expanded_columns.push(expanded);
+            expanded_fields.push(Arc::new(batch.schema().field(i).clone()));
+        }

Review Comment:
   Thanks, @gstvg. Excellent advice. I have implemented this now.



##########
spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala:
##########
@@ -930,6 +930,82 @@ class CometArrayExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_exists - DataFrame API") {
+    val table = "t1"
+    withTable(table) {
+      sql(s"create table $table(arr array<int>, threshold int) using parquet")
+      sql(s"insert into $table values (array(1, 2, 3), 2)")
+      sql(s"insert into $table values (array(1, 2), 5)")
+      sql(s"insert into $table values (array(), 0)")
+      sql(s"insert into $table values (null, 1)")
+      sql(s"insert into $table values (array(1, null, 3), 2)")
+
+      val df = spark.table(table)
+
+      checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2)))
+      checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 
col("threshold"))))
+      checkSparkAnswerAndOperator(
+        df.select(
+          exists(col("arr"), x => x > 0).as("any_positive"),
+          exists(col("arr"), x => x > 100).as("any_large")))
+    }
+  }
+
+  test("array_exists - DataFrame API with decimal") {
+    val table = "t1"
+    withTable(table) {
+      sql(s"create table $table(arr array<decimal(10,2)>) using parquet")
+      sql(s"insert into $table values (array(1.50, 2.75, 3.25))")
+      sql(s"insert into $table values (array(0.10, 0.20))")
+
+      val df = spark.table(table)
+      checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2.0)))
+    }
+  }
+
+  test("array_exists - DataFrame API with date") {
+    val table = "t1"
+    withTable(table) {
+      sql(s"create table $table(arr array<date>) using parquet")
+      sql(s"insert into $table values (array(date'2024-01-01', 
date'2024-06-15'))")
+      sql(s"insert into $table values (array(date'2023-01-01'))")
+
+      val df = spark.table(table)
+      checkSparkAnswerAndOperator(
+        df.select(exists(col("arr"), x => x > lit("2024-03-01").cast("date"))))

Review Comment:
   Added



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to