Dandandan commented on code in PR #18832:
URL: https://github.com/apache/datafusion/pull/18832#discussion_r2549501179


##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -198,68 +206,122 @@ impl ArrayStaticFilter {
     }
 }
 
-struct Int32StaticFilter {
-    null_count: usize,
-    values: HashSet<i32>,
-}
+// Macro to generate specialized StaticFilter implementations for primitive types
+macro_rules! primitive_static_filter {
+    ($Name:ident, $ArrowType:ty) => {
+        struct $Name {
+            null_count: usize,
+            values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>,
+        }
 
-impl Int32StaticFilter {
-    fn try_new(in_array: &ArrayRef) -> Result<Self> {
-        let in_array = in_array
-            .as_primitive_opt::<Int32Type>()
-            .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?;
+        impl $Name {
+            fn try_new(in_array: &ArrayRef) -> Result<Self> {
+                let in_array = in_array
+                    .as_primitive_opt::<$ArrowType>()
+                    .ok_or_else(|| exec_datafusion_err!(format!("Failed to downcast an array to a '{}' array", stringify!($ArrowType))))?;
 
-        let mut values = HashSet::with_capacity(in_array.len());
-        let null_count = in_array.null_count();
+                let mut values = HashSet::with_capacity(in_array.len());
+                let null_count = in_array.null_count();
+
+                for v in in_array.iter().flatten() {
+                    values.insert(v);
+                }
 
-        for v in in_array.iter().flatten() {
-            values.insert(v);
+                Ok(Self { null_count, values })
+            }
         }
 
-        Ok(Self { null_count, values })
-    }
-}
+        impl StaticFilter for $Name {
+            fn null_count(&self) -> usize {
+                self.null_count
+            }
 
-impl StaticFilter for Int32StaticFilter {
-    fn null_count(&self) -> usize {
-        self.null_count
-    }
+            fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
+                // Handle dictionary arrays by recursing on the values
+                downcast_dictionary_array! {
+                    v => {
+                        let values_contains = self.contains(v.values().as_ref(), negated)?;
+                        let result = take(&values_contains, v.keys(), None)?;
+                        return Ok(downcast_array(result.as_ref()))
+                    }
+                    _ => {}
+                }
 
-    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
-        let v = v
-            .as_primitive_opt::<Int32Type>()
-            .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?;
-
-        let result = match (v.null_count() > 0, negated) {
-            (true, false) => {
-                // has nulls, not negated"
-                BooleanArray::from_iter(
-                    v.iter().map(|value| Some(self.values.contains(&value?))),
-                )
-            }
-            (true, true) => {
-                // has nulls, negated
-                BooleanArray::from_iter(
-                    v.iter().map(|value| Some(!self.values.contains(&value?))),
-                )
-            }
-            (false, false) => {
-                //no null, not negated
-                BooleanArray::from_iter(
-                    v.values().iter().map(|value| self.values.contains(value)),
-                )
-            }
-            (false, true) => {
-                // no null, negated
-                BooleanArray::from_iter(
-                    v.values().iter().map(|value| 
!self.values.contains(value)),
-                )
+                let v = v
+                    .as_primitive_opt::<$ArrowType>()
+                    .ok_or_else(|| exec_datafusion_err!(format!("Failed to downcast an array to a '{}' array", stringify!($ArrowType))))?;
+
+                let haystack_has_nulls = self.null_count > 0;
+
+                let result = match (v.null_count() > 0, haystack_has_nulls, negated) {
+                    (true, _, false) | (false, true, false) => {
+                        // Either needle or haystack has nulls, not negated
+                        BooleanArray::from_iter(v.iter().map(|value| {

Review Comment:
   `BooleanArray::collect_bool` is much faster than `BooleanArray::from_iter` here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to