This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 7943199c8 perf: Optimize contains expression with SIMD-based scalar
pattern search (#2991)
7943199c8 is described below
commit 7943199c88a9b1e3f33afbbca430355d54d02881
Author: Shekhar Prasad Rajak <[email protected]>
AuthorDate: Tue Feb 10 03:45:33 2026 +0530
perf: Optimize contains expression with SIMD-based scalar pattern search
(#2991)
---
native/spark-expr/src/comet_scalar_funcs.rs | 5 +-
native/spark-expr/src/string_funcs/contains.rs | 246 +++++++++++++++++++++
native/spark-expr/src/string_funcs/mod.rs | 2 +
.../org/apache/comet/CometExpressionSuite.scala | 19 +-
4 files changed, 269 insertions(+), 3 deletions(-)
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs
b/native/spark-expr/src/comet_scalar_funcs.rs
index 760dc3570..6647e01cc 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_array_repeat, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor,
spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding,
spark_round, spark_rpad,
- spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount,
SparkDateDiff, SparkDateTrunc,
- SparkSizeFunc, SparkStringSpace,
+ spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount,
SparkContains, SparkDateDiff,
+ SparkDateTrunc, SparkSizeFunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
vec![
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
+ Arc::new(ScalarUDF::new_from_impl(SparkContains::default())),
Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
diff --git a/native/spark-expr/src/string_funcs/contains.rs
b/native/spark-expr/src/string_funcs/contains.rs
new file mode 100644
index 000000000..bc34ce9cb
--- /dev/null
+++ b/native/spark-expr/src/string_funcs/contains.rs
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimized `contains` string function for Spark compatibility.
+//!
+//! Optimized for scalar pattern case by passing scalar directly to
arrow_contains
+//! instead of expanding to arrays like DataFusion's built-in contains.
+
+use arrow::array::{Array, ArrayRef, BooleanArray, StringArray};
+use arrow::compute::kernels::comparison::contains as arrow_contains;
+use arrow::datatypes::DataType;
+use datafusion::common::{exec_err, Result, ScalarValue};
+use datafusion::logical_expr::{
+ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-optimized contains function.
+/// Returns true if the first string argument contains the second string
argument.
+/// Optimized for scalar pattern constants.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkContains {
+ signature: Signature,
+}
+
+impl Default for SparkContains {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl SparkContains {
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::variadic_any(Volatility::Immutable),
+ }
+ }
+}
+
+impl ScalarUDFImpl for SparkContains {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "contains"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Boolean)
+ }
+
+ fn invoke_with_args(&self, args: ScalarFunctionArgs) ->
Result<ColumnarValue> {
+ if args.args.len() != 2 {
+ return exec_err!("contains function requires exactly 2 arguments");
+ }
+ spark_contains(&args.args[0], &args.args[1])
+ }
+}
+
+/// Execute contains function with optimized scalar pattern handling.
+fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) ->
Result<ColumnarValue> {
+ match (haystack, needle) {
+ // Both arrays - use arrow's contains directly
+ (ColumnarValue::Array(haystack_array),
ColumnarValue::Array(needle_array)) => {
+ let result = arrow_contains(haystack_array, needle_array)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+
+ // Array haystack, scalar needle - OPTIMIZED PATH
+ (ColumnarValue::Array(haystack_array),
ColumnarValue::Scalar(needle_scalar)) => {
+ let result = contains_with_arrow_scalar(haystack_array,
needle_scalar)?;
+ Ok(ColumnarValue::Array(result))
+ }
+
+ // Scalar haystack, array needle - less common
+ (ColumnarValue::Scalar(haystack_scalar),
ColumnarValue::Array(needle_array)) => {
+ let haystack_array =
haystack_scalar.to_array_of_size(needle_array.len())?;
+ let result = arrow_contains(&haystack_array, needle_array)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+
+ // Both scalars - compute single result
+ (ColumnarValue::Scalar(haystack_scalar),
ColumnarValue::Scalar(needle_scalar)) => {
+ let result = contains_scalar_scalar(haystack_scalar,
needle_scalar)?;
+ Ok(ColumnarValue::Scalar(result))
+ }
+ }
+}
+
+/// Optimized contains for array haystack with scalar needle.
+/// Uses Arrow's native scalar handling for better performance.
+fn contains_with_arrow_scalar(
+ haystack_array: &ArrayRef,
+ needle_scalar: &ScalarValue,
+) -> Result<ArrayRef> {
+ // Handle null needle
+ if needle_scalar.is_null() {
+ return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
+ }
+
+ // Extract the needle string
+ let needle_str = match needle_scalar {
+ ScalarValue::Utf8(Some(s))
+ | ScalarValue::LargeUtf8(Some(s))
+ | ScalarValue::Utf8View(Some(s)) => s.clone(),
+ _ => {
+ return exec_err!(
+ "contains function requires string type for needle, got {:?}",
+ needle_scalar.data_type()
+ )
+ }
+ };
+
+ // Create scalar array for needle - tells Arrow to use optimized paths
+ let needle_scalar_array = StringArray::new_scalar(needle_str);
+
+ // Use Arrow's contains which detects scalar case and uses optimized paths
+ let result = arrow_contains(haystack_array, &needle_scalar_array)?;
+ Ok(Arc::new(result))
+}
+
+/// Contains for two scalar values.
+fn contains_scalar_scalar(
+ haystack_scalar: &ScalarValue,
+ needle_scalar: &ScalarValue,
+) -> Result<ScalarValue> {
+ // Handle nulls
+ if haystack_scalar.is_null() || needle_scalar.is_null() {
+ return Ok(ScalarValue::Boolean(None));
+ }
+
+ let haystack_str = match haystack_scalar {
+ ScalarValue::Utf8(Some(s))
+ | ScalarValue::LargeUtf8(Some(s))
+ | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+ _ => {
+ return exec_err!(
+ "contains function requires string type for haystack, got
{:?}",
+ haystack_scalar.data_type()
+ )
+ }
+ };
+
+ let needle_str = match needle_scalar {
+ ScalarValue::Utf8(Some(s))
+ | ScalarValue::LargeUtf8(Some(s))
+ | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+ _ => {
+ return exec_err!(
+ "contains function requires string type for needle, got {:?}",
+ needle_scalar.data_type()
+ )
+ }
+ };
+
+ Ok(ScalarValue::Boolean(Some(
+ haystack_str.contains(needle_str),
+ )))
+}
+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    /// Downcasts a kernel result to the `BooleanArray` that `contains` produces.
    fn as_bools(result: &ArrayRef) -> &BooleanArray {
        result
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("contains should return a BooleanArray")
    }

    /// Shorthand for a non-null `Utf8` scalar.
    fn utf8(value: &str) -> ScalarValue {
        ScalarValue::Utf8(Some(value.to_string()))
    }

    #[test]
    fn test_contains_array_scalar() {
        let haystack: ArrayRef = Arc::new(StringArray::from(vec![
            Some("hello world"),
            Some("foo bar"),
            Some("testing"),
            None,
        ]));

        let result = contains_with_arrow_scalar(&haystack, &utf8("world")).unwrap();
        let matches = as_bools(&result);

        // Only "hello world" contains "world"; the null row stays null.
        assert!(matches.value(0));
        assert!(!matches.value(1));
        assert!(!matches.value(2));
        assert!(matches.is_null(3));
    }

    #[test]
    fn test_contains_scalar_scalar() {
        let haystack = utf8("hello world");
        for (needle, expected) in [("world", true), ("xyz", false)] {
            assert_eq!(
                contains_scalar_scalar(&haystack, &utf8(needle)).unwrap(),
                ScalarValue::Boolean(Some(expected))
            );
        }
    }

    #[test]
    fn test_contains_null_needle() {
        let haystack: ArrayRef =
            Arc::new(StringArray::from(vec![Some("hello world"), Some("foo bar")]));

        let result = contains_with_arrow_scalar(&haystack, &ScalarValue::Utf8(None)).unwrap();
        let matches = as_bools(&result);

        // A null needle nulls out every row.
        assert!((0..2).all(|row| matches.is_null(row)));
    }

    #[test]
    fn test_contains_empty_needle() {
        let haystack: ArrayRef = Arc::new(StringArray::from(vec![Some("hello world"), Some("")]));

        let result = contains_with_arrow_scalar(&haystack, &utf8("")).unwrap();
        let matches = as_bools(&result);

        // The empty string is contained in every string, including the empty one.
        assert!(matches.value(0));
        assert!(matches.value(1));
    }
}
diff --git a/native/spark-expr/src/string_funcs/mod.rs
b/native/spark-expr/src/string_funcs/mod.rs
index aac8204e2..abdd0cc89 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -15,8 +15,10 @@
// specific language governing permissions and limitations
// under the License.
+mod contains;
mod string_space;
mod substring;
+pub use contains::SparkContains;
pub use string_space::SparkStringSpace;
pub use substring::SubstringExpr;
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 2999d8bfe..91428bb61 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1163,7 +1163,24 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
// Filter rows that contains 'rose' in 'name' column
val queryContains = sql(s"select id from $table where contains (name,
'rose')")
- checkAnswer(queryContains, Row(5) :: Nil)
+ checkSparkAnswerAndOperator(queryContains)
+
+ // Additional test cases for optimized contains implementation
+ // Test with empty pattern (should match all non-null rows)
+ val queryEmptyPattern = sql(s"select id from $table where contains
(name, '')")
+ checkSparkAnswerAndOperator(queryEmptyPattern)
+
+ // Test with pattern not found
+ val queryNotFound = sql(s"select id from $table where contains (name,
'xyz')")
+ checkSparkAnswerAndOperator(queryNotFound)
+
+ // Test with pattern at start
+ val queryStart = sql(s"select id from $table where contains (name,
'James')")
+ checkSparkAnswerAndOperator(queryStart)
+
+ // Test with pattern at end
+ val queryEnd = sql(s"select id from $table where contains (name,
'Smith')")
+ checkSparkAnswerAndOperator(queryEnd)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]