This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 7943199c8 perf: Optimize contains expression with SIMD-based scalar
pattern search (#2991)
7943199c8 is described below

commit 7943199c88a9b1e3f33afbbca430355d54d02881
Author: Shekhar Prasad Rajak <[email protected]>
AuthorDate: Tue Feb 10 03:45:33 2026 +0530

    perf: Optimize contains expression with SIMD-based scalar pattern search
(#2991)
---
 native/spark-expr/src/comet_scalar_funcs.rs        |   5 +-
 native/spark-expr/src/string_funcs/contains.rs     | 246 +++++++++++++++++++++
 native/spark-expr/src/string_funcs/mod.rs          |   2 +
 .../org/apache/comet/CometExpressionSuite.scala    |  19 +-
 4 files changed, 269 insertions(+), 3 deletions(-)

diff --git a/native/spark-expr/src/comet_scalar_funcs.rs 
b/native/spark-expr/src/comet_scalar_funcs.rs
index 760dc3570..6647e01cc 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, 
spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, 
spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, 
SparkDateDiff, SparkDateTrunc,
-    SparkSizeFunc, SparkStringSpace,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, 
SparkContains, SparkDateDiff,
+    SparkDateTrunc, SparkSizeFunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
diff --git a/native/spark-expr/src/string_funcs/contains.rs 
b/native/spark-expr/src/string_funcs/contains.rs
new file mode 100644
index 000000000..bc34ce9cb
--- /dev/null
+++ b/native/spark-expr/src/string_funcs/contains.rs
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimized `contains` string function for Spark compatibility.
+//!
+//! Optimized for scalar pattern case by passing scalar directly to 
arrow_contains
+//! instead of expanding to arrays like DataFusion's built-in contains.
+
+use arrow::array::{Array, ArrayRef, BooleanArray, StringArray};
+use arrow::compute::kernels::comparison::contains as arrow_contains;
+use arrow::datatypes::DataType;
+use datafusion::common::{exec_err, Result, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-optimized contains function.
+/// Returns true if the first string argument contains the second string 
argument.
+/// Optimized for scalar pattern constants.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkContains {
+    signature: Signature,
+}
+
+impl Default for SparkContains {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkContains {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkContains {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
+        if args.args.len() != 2 {
+            return exec_err!("contains function requires exactly 2 arguments");
+        }
+        spark_contains(&args.args[0], &args.args[1])
+    }
+}
+
+/// Execute contains function with optimized scalar pattern handling.
+fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> 
Result<ColumnarValue> {
+    match (haystack, needle) {
+        // Both arrays - use arrow's contains directly
+        (ColumnarValue::Array(haystack_array), 
ColumnarValue::Array(needle_array)) => {
+            let result = arrow_contains(haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Array haystack, scalar needle - OPTIMIZED PATH
+        (ColumnarValue::Array(haystack_array), 
ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_with_arrow_scalar(haystack_array, 
needle_scalar)?;
+            Ok(ColumnarValue::Array(result))
+        }
+
+        // Scalar haystack, array needle - less common
+        (ColumnarValue::Scalar(haystack_scalar), 
ColumnarValue::Array(needle_array)) => {
+            let haystack_array = 
haystack_scalar.to_array_of_size(needle_array.len())?;
+            let result = arrow_contains(&haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Both scalars - compute single result
+        (ColumnarValue::Scalar(haystack_scalar), 
ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_scalar_scalar(haystack_scalar, 
needle_scalar)?;
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+/// Optimized contains for array haystack with scalar needle.
+/// Uses Arrow's native scalar handling for better performance.
+fn contains_with_arrow_scalar(
+    haystack_array: &ArrayRef,
+    needle_scalar: &ScalarValue,
+) -> Result<ArrayRef> {
+    // Handle null needle
+    if needle_scalar.is_null() {
+        return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
+    }
+
+    // Extract the needle string
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.clone(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    // Create scalar array for needle - tells Arrow to use optimized paths
+    let needle_scalar_array = StringArray::new_scalar(needle_str);
+
+    // Use Arrow's contains which detects scalar case and uses optimized paths
+    let result = arrow_contains(haystack_array, &needle_scalar_array)?;
+    Ok(Arc::new(result))
+}
+
+/// Contains for two scalar values.
+fn contains_scalar_scalar(
+    haystack_scalar: &ScalarValue,
+    needle_scalar: &ScalarValue,
+) -> Result<ScalarValue> {
+    // Handle nulls
+    if haystack_scalar.is_null() || needle_scalar.is_null() {
+        return Ok(ScalarValue::Boolean(None));
+    }
+
+    let haystack_str = match haystack_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for haystack, got 
{:?}",
+                haystack_scalar.data_type()
+            )
+        }
+    };
+
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    Ok(ScalarValue::Boolean(Some(
+        haystack_str.contains(needle_str),
+    )))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::StringArray;
+
+    #[test]
+    fn test_contains_array_scalar() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+            Some("testing"),
+            None,
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
+        let bool_array = 
result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        assert!(bool_array.value(0)); // "hello world" contains "world"
+        assert!(!bool_array.value(1)); // "foo bar" does not contain "world"
+        assert!(!bool_array.value(2)); // "testing" does not contain "world"
+        assert!(bool_array.is_null(3)); // null input => null output
+    }
+
+    #[test]
+    fn test_contains_scalar_scalar() {
+        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(true)));
+
+        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
+        let result = contains_scalar_scalar(&haystack, 
&needle_not_found).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(false)));
+    }
+
+    #[test]
+    fn test_contains_null_needle() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(None);
+
+        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
+        let bool_array = 
result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Null needle should produce null results
+        assert!(bool_array.is_null(0));
+        assert!(bool_array.is_null(1));
+    }
+
+    #[test]
+    fn test_contains_empty_needle() {
+        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), 
Some("")])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("".to_string()));
+
+        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
+        let bool_array = 
result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Empty string is contained in any string
+        assert!(bool_array.value(0));
+        assert!(bool_array.value(1));
+    }
+}
diff --git a/native/spark-expr/src/string_funcs/mod.rs 
b/native/spark-expr/src/string_funcs/mod.rs
index aac8204e2..abdd0cc89 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -15,8 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod contains;
 mod string_space;
 mod substring;
 
+pub use contains::SparkContains;
 pub use string_space::SparkStringSpace;
 pub use substring::SubstringExpr;
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 2999d8bfe..91428bb61 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1163,7 +1163,24 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
 
       // Filter rows that contains 'rose' in 'name' column
       val queryContains = sql(s"select id from $table where contains (name, 
'rose')")
-      checkAnswer(queryContains, Row(5) :: Nil)
+      checkSparkAnswerAndOperator(queryContains)
+
+      // Additional test cases for optimized contains implementation
+      // Test with empty pattern (should match all non-null rows)
+      val queryEmptyPattern = sql(s"select id from $table where contains 
(name, '')")
+      checkSparkAnswerAndOperator(queryEmptyPattern)
+
+      // Test with pattern not found
+      val queryNotFound = sql(s"select id from $table where contains (name, 
'xyz')")
+      checkSparkAnswerAndOperator(queryNotFound)
+
+      // Test with pattern at start
+      val queryStart = sql(s"select id from $table where contains (name, 
'James')")
+      checkSparkAnswerAndOperator(queryStart)
+
+      // Test with pattern at end
+      val queryEnd = sql(s"select id from $table where contains (name, 
'Smith')")
+      checkSparkAnswerAndOperator(queryEnd)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to