Dandandan commented on code in PR #18832:
URL: https://github.com/apache/datafusion/pull/18832#discussion_r2549507569
##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -198,68 +206,122 @@ impl ArrayStaticFilter {
}
}
-struct Int32StaticFilter {
- null_count: usize,
- values: HashSet<i32>,
-}
+// Macro to generate specialized StaticFilter implementations for primitive
types
+macro_rules! primitive_static_filter {
+ ($Name:ident, $ArrowType:ty) => {
+ struct $Name {
+ null_count: usize,
+ values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>,
+ }
-impl Int32StaticFilter {
- fn try_new(in_array: &ArrayRef) -> Result<Self> {
- let in_array = in_array
- .as_primitive_opt::<Int32Type>()
- .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?;
+ impl $Name {
+ fn try_new(in_array: &ArrayRef) -> Result<Self> {
+ let in_array = in_array
+ .as_primitive_opt::<$ArrowType>()
+ .ok_or_else(|| exec_datafusion_err!(format!("Failed to
downcast an array to a '{}' array", stringify!($ArrowType))))?;
- let mut values = HashSet::with_capacity(in_array.len());
- let null_count = in_array.null_count();
+ let mut values = HashSet::with_capacity(in_array.len());
+ let null_count = in_array.null_count();
+
+ for v in in_array.iter().flatten() {
+ values.insert(v);
+ }
- for v in in_array.iter().flatten() {
- values.insert(v);
+ Ok(Self { null_count, values })
+ }
}
- Ok(Self { null_count, values })
- }
-}
+ impl StaticFilter for $Name {
+ fn null_count(&self) -> usize {
+ self.null_count
+ }
-impl StaticFilter for Int32StaticFilter {
- fn null_count(&self) -> usize {
- self.null_count
- }
+ fn contains(&self, v: &dyn Array, negated: bool) ->
Result<BooleanArray> {
+ // Handle dictionary arrays by recursing on the values
+ downcast_dictionary_array! {
+ v => {
+ let values_contains =
self.contains(v.values().as_ref(), negated)?;
+ let result = take(&values_contains, v.keys(), None)?;
+ return Ok(downcast_array(result.as_ref()))
+ }
+ _ => {}
+ }
- fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
- let v = v
- .as_primitive_opt::<Int32Type>()
- .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?;
-
- let result = match (v.null_count() > 0, negated) {
- (true, false) => {
- // has nulls, not negated"
- BooleanArray::from_iter(
- v.iter().map(|value| Some(self.values.contains(&value?))),
- )
- }
- (true, true) => {
- // has nulls, negated
- BooleanArray::from_iter(
- v.iter().map(|value| Some(!self.values.contains(&value?))),
- )
- }
- (false, false) => {
- //no null, not negated
- BooleanArray::from_iter(
- v.values().iter().map(|value| self.values.contains(value)),
- )
- }
- (false, true) => {
- // no null, negated
- BooleanArray::from_iter(
- v.values().iter().map(|value|
!self.values.contains(value)),
- )
+ let v = v
+ .as_primitive_opt::<$ArrowType>()
+ .ok_or_else(|| exec_datafusion_err!(format!("Failed to
downcast an array to a '{}' array", stringify!($ArrowType))))?;
+
+ let haystack_has_nulls = self.null_count > 0;
+
+ let result = match (v.null_count() > 0, haystack_has_nulls,
negated) {
+ (true, _, false) | (false, true, false) => {
+ // Either needle or haystack has nulls, not negated
+ BooleanArray::from_iter(v.iter().map(|value| {
+ match value {
+ // SQL three-valued logic: null IN (...) is
always null
+ None => None,
+ Some(v) => {
+ if self.values.contains(&v) {
+ Some(true)
+ } else if haystack_has_nulls {
+ // value not in set, but set has nulls
-> null
+ None
+ } else {
+ Some(false)
+ }
+ }
+ }
+ }))
+ }
+ (true, _, true) | (false, true, true) => {
+ // Either needle or haystack has nulls, negated
+ BooleanArray::from_iter(v.iter().map(|value| {
+ match value {
+ // SQL three-valued logic: null NOT IN (...)
is always null
+ None => None,
+ Some(v) => {
+ if self.values.contains(&v) {
+ Some(false)
+ } else if haystack_has_nulls {
+ // value not in set, but set has nulls
-> null
+ None
+ } else {
+ Some(true)
+ }
+ }
+ }
+ }))
+ }
+ (false, false, false) => {
+ // no nulls anywhere, not negated
+ BooleanArray::from_iter(
Review Comment:
`BooleanBuffer::collect_bool` is faster
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]