tustvold commented on code in PR #4057:
URL: https://github.com/apache/arrow-datafusion/pull/4057#discussion_r1010215858


##########
datafusion/physical-expr/src/expressions/in_list.rs:
##########
@@ -70,320 +57,185 @@ impl Debug for InListExpr {
             .field("expr", &self.expr)
             .field("list", &self.list)
             .field("negated", &self.negated)
-            .field("set", &self.set)
             .finish()
     }
 }
 
-/// InSet
-#[derive(Debug, PartialEq, Eq)]
-pub struct InSet {
-    // TODO: optimization: In the `IN` or `NOT IN` we don't need to consider 
the NULL value
-    // The data type is same, we can use  set: HashSet<T>
-    set: HashSet<ScalarValue>,
+/// A type-erased container of array elements
+trait Set: Send + Sync {
+    fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray;
 }
 
-impl InSet {
-    pub fn new(set: HashSet<ScalarValue>) -> Self {
-        Self { set }
-    }
-
-    pub fn get_set(&self) -> &HashSet<ScalarValue> {
-        &self.set
-    }
+struct ArrayHashSet {
+    state: RandomState,
+    /// Used to provide a lookup from value to in list index
+    ///
+    /// Note: usize::hash is not used, instead the raw entry
+    /// API is used to store entries w.r.t their value has
+    map: HashMap<usize, (), ()>,
 }
 
-macro_rules! make_contains {
-    ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, 
$ARRAY_TYPE:ident) => {{
-        let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
-
-        let contains_null = $LIST_VALUES
-            .iter()
-            .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
-        let values = $LIST_VALUES
-            .iter()
-            .flat_map(|expr| match expr {
-                ColumnarValue::Scalar(s) => match s {
-                    ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v),
-                    ScalarValue::$SCALAR_VALUE(None) => None,
-                    datatype => unreachable!("InList can't reach other data 
type {} for {}.", datatype, s),
-                },
-                ColumnarValue::Array(_) => {
-                    unimplemented!("InList does not yet support nested 
columns.")
-                }
-            })
-            .collect::<Vec<_>>();
-
-        collection_contains_check!(array, values, $NEGATED, contains_null)
-    }};
+struct ArraySet<T> {
+    array: T,
+    hash_set: Option<ArrayHashSet>,
 }
 
-macro_rules! make_contains_primitive {
-    ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, 
$ARRAY_TYPE:ident) => {{
-        let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
-
-        let contains_null = $LIST_VALUES
-            .iter()
-            .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
-        let values = $LIST_VALUES
-            .iter()
-            .flat_map(|expr| match expr {
-                ColumnarValue::Scalar(s) => match s {
-                    ScalarValue::$SCALAR_VALUE(Some(v), ..) => Some(*v),
-                    ScalarValue::$SCALAR_VALUE(None, ..) => None,
-                    datatype => unreachable!("InList can't reach other data 
type {} for {}.", datatype, s),
-                },
-                ColumnarValue::Array(_) => {
-                    unimplemented!("InList does not yet support nested 
columns.")
-                }
-            })
-            .collect::<Vec<_>>();
-
-        Ok(collection_contains_check!(array, values, $NEGATED, contains_null))
-    }};
+impl<T> ArraySet<T>
+where
+    T: Array + From<ArrayData>,
+{
+    fn new(array: &T, hash_set: Option<ArrayHashSet>) -> Self {
+        Self {
+            array: T::from(array.data().clone()),
+            hash_set,
+        }
+    }
 }
 
-macro_rules! set_contains_for_float {
-    ($ARRAY:expr, $SET_VALUES:expr, $SCALAR_VALUE:ident, $NEGATED:expr) => {{
-        let contains_null = $SET_VALUES.iter().any(|s| s.is_null());
-        let bool_array = if $NEGATED {
-            // Not in
-            if contains_null {
-                $ARRAY
-                    .iter()
-                    .map(|vop| {
-                        match vop.map(|v| 
!$SET_VALUES.contains(&v.try_into().unwrap())) {
-                            Some(true) => None,
-                            x => x,
+impl<T> Set for ArraySet<T>
+where
+    T: Array + 'static,
+    for<'a> &'a T: ArrayAccessor,
+    for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue,
+{
+    fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray {
+        let v = v.as_any().downcast_ref::<T>().unwrap();
+        let in_data = self.array.data();
+        let in_array = &self.array;
+        let has_nulls = in_data.null_count() != 0;
+
+        match &self.hash_set {
+            Some(hash_set) => ArrayIter::new(v)
+                .map(|v| {
+                    v.and_then(|v| {
+                        let hash = v.hash_one(&hash_set.state);
+                        let contains = hash_set
+                            .map
+                            .raw_entry()
+                            .from_hash(hash, |idx| in_array.value(*idx) == v)
+                            .is_some();
+
+                        match contains {
+                            true => Some(!negated),
+                            false if has_nulls => None,
+                            false => Some(negated),
                         }
                     })
-                    .collect::<BooleanArray>()
-            } else {
-                $ARRAY
-                    .iter()
-                    .map(|vop| vop.map(|v| 
!$SET_VALUES.contains(&v.try_into().unwrap())))
-                    .collect::<BooleanArray>()
-            }
-        } else {
-            // In
-            if contains_null {
-                $ARRAY
-                    .iter()
-                    .map(|vop| {
-                        match vop.map(|v| 
$SET_VALUES.contains(&v.try_into().unwrap())) {
-                            Some(false) => None,
-                            x => x,
-                        }
+                })
+                .collect(),
+            None => ArrayIter::new(v)
+                .map(|v| {
+                    v.map(|v| {
+                        let contains = (0..in_data.len()).any(|x| 
in_array.value(x) == v);
+                        contains != negated
                     })
-                    .collect::<BooleanArray>()
-            } else {
-                $ARRAY
-                    .iter()
-                    .map(|vop| vop.map(|v| 
$SET_VALUES.contains(&v.try_into().unwrap())))
-                    .collect::<BooleanArray>()
-            }
-        };
-        ColumnarValue::Array(Arc::new(bool_array))
-    }};
+                })
+                .collect(),
+        }
+    }
 }
 
-macro_rules! set_contains_for_primitive {
-    ($ARRAY:expr, $SET_VALUES:expr, $SCALAR_VALUE:ident, $NEGATED:expr) => {{
-        let contains_null = $SET_VALUES.iter().any(|s| s.is_null());
-        let native_set = $SET_VALUES
-            .iter()
-            .flat_map(|v| match v {
-                $SCALAR_VALUE(value, ..) => *value,
-                datatype => {
-                    unreachable!(
-                        "InList can't reach other data type {} for {}.",
-                        datatype, v
-                    )
-                }
-            })
-            .collect::<HashSet<_>>();
+/// Computes an [`ArrayHashSet`] for the provided [`Array`] if there are nulls 
present
+/// or there are more than [`OPTIMIZER_INSET_THRESHOLD`] values
+///
+/// Note: This is split into a separate function as higher-rank trait bounds 
currently
+/// cause type inference to misbehave
+fn make_hash_set<T>(array: T) -> Option<ArrayHashSet>
+where
+    T: ArrayAccessor,
+    T::Item: PartialEq + HashValue,
+{
+    let use_hashset = array.null_count() != 0 || array.len() >= 
OPTIMIZER_INSET_THRESHOLD;
+    if !use_hashset {
+        return None;
+    }
+    let data = array.data();
+
+    let state = RandomState::new();
+    let mut map: HashMap<usize, (), ()> =
+        HashMap::with_capacity_and_hasher(data.len(), ());
+
+    let insert_value = |idx| {
+        let value = array.value(idx);
+        let hash = value.hash_one(&state);
+        if let RawEntryMut::Vacant(v) = map
+            .raw_entry_mut()
+            .from_hash(hash, |x| array.value(*x) == value)
+        {
+            v.insert_with_hasher(hash, idx, (), |x| 
array.value(*x).hash_one(&state));
+        }
+    };
 
-        collection_contains_check!($ARRAY, native_set, $NEGATED, contains_null)
-    }};
-}
+    match data.null_buffer() {
+        Some(buffer) => BitIndexIterator::new(buffer.as_ref(), data.offset(), 
data.len())
+            .for_each(insert_value),
+        None => (0..data.len()).for_each(insert_value),
+    }
 
-macro_rules! collection_contains_check {
-    ($ARRAY:expr, $VALUES:expr, $NEGATED:expr, $CONTAINS_NULL:expr) => {{
-        let bool_array = if $NEGATED {
-            // Not in
-            if $CONTAINS_NULL {
-                $ARRAY
-                    .iter()
-                    .map(|vop| match vop.map(|v| !$VALUES.contains(&v)) {
-                        Some(true) => None,
-                        x => x,
-                    })
-                    .collect::<BooleanArray>()
-            } else {
-                $ARRAY
-                    .iter()
-                    .map(|vop| vop.map(|v| !$VALUES.contains(&v)))
-                    .collect::<BooleanArray>()
-            }
-        } else {
-            // In
-            if $CONTAINS_NULL {
-                $ARRAY
-                    .iter()
-                    .map(|vop| match vop.map(|v| $VALUES.contains(&v)) {
-                        Some(false) => None,
-                        x => x,
-                    })
-                    .collect::<BooleanArray>()
-            } else {
-                $ARRAY
-                    .iter()
-                    .map(|vop| vop.map(|v| $VALUES.contains(&v)))
-                    .collect::<BooleanArray>()
-            }
-        };
-        ColumnarValue::Array(Arc::new(bool_array))
-    }};
+    Some(ArrayHashSet { state, map })
 }
 
-macro_rules! collection_contains_check_decimal {
-    ($ARRAY:expr, $VALUES:expr, $NEGATED:expr, $CONTAINS_NULL:expr) => {{
-        let bool_array = if $NEGATED {
-            // Not in
-            if $CONTAINS_NULL {
-                $ARRAY
-                    .iter()
-                    .map(|vop| match vop.map(|v| !$VALUES.contains(&v)) {
-                        Some(true) => None,
-                        x => x,
-                    })
-                    .collect::<BooleanArray>()
-            } else {
-                $ARRAY
-                    .iter()
-                    .map(|vop| vop.map(|v| !$VALUES.contains(&v)))
-                    .collect::<BooleanArray>()
-            }
-        } else {
-            // In
-            if $CONTAINS_NULL {
-                $ARRAY
-                    .iter()
-                    .map(|vop| match vop.map(|v| $VALUES.contains(&v)) {
-                        Some(false) => None,
-                        x => x,
-                    })
-                    .collect::<BooleanArray>()
-            } else {
-                $ARRAY
-                    .iter()
-                    .map(|vop| vop.map(|v| $VALUES.contains(&v)))
-                    .collect::<BooleanArray>()
-            }
-        };
-        ColumnarValue::Array(Arc::new(bool_array))
-    }};
+/// Creates a `Box<dyn Set>` for the given list of `IN` expressions and `batch`
+fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
+    Ok(downcast_primitive_array! {
+        array => Box::new(ArraySet::new(array, make_hash_set(array))),
+        DataType::Boolean => {
+            let array = as_boolean_array(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        },
+        DataType::Decimal128(_, _) => {
+            let array = as_primitive_array::<Decimal128Type>(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        }
+        DataType::Decimal256(_, _) => {
+            let array = as_primitive_array::<Decimal256Type>(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        }
+        DataType::Utf8 => {
+            let array = as_string_array(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        }
+        DataType::LargeUtf8 => {
+            let array = as_largestring_array(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        }
+        DataType::Binary => {
+            let array = as_generic_binary_array::<i32>(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        }
+        DataType::LargeBinary => {
+            let array = as_generic_binary_array::<i64>(array);
+            Box::new(ArraySet::new(array, make_hash_set(array)))
+        }
+        d => return Err(DataFusionError::NotImplemented(format!("DataType::{} 
not supported in InList", d)))
+    })
 }
 
-// try evaluate all list exprs and check if the exprs are constants or not
-fn try_cast_static_filter_to_set(
+fn evaluate_list(
     list: &[Arc<dyn PhysicalExpr>],
-    schema: &Schema,
-) -> Result<HashSet<ScalarValue>> {
-    let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));
-    list.iter()
-        .map(|expr| match expr.evaluate(&batch) {
-            Ok(ColumnarValue::Array(_)) => Err(DataFusionError::NotImplemented(
-                "InList doesn't support to evaluate the array 
result".to_string(),
-            )),
-            Ok(ColumnarValue::Scalar(s)) => Ok(s),
-            Err(e) => Err(e),
-        })
-        .collect::<Result<HashSet<_>>>()
-}
-
-fn make_list_contains_decimal(
-    array: &Decimal128Array,
-    list: Vec<ColumnarValue>,
-    negated: bool,
-) -> ColumnarValue {
-    let contains_null = list
-        .iter()
-        .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
-    let values = list
+    batch: &RecordBatch,
+) -> Result<ArrayRef> {
+    let scalars = list
         .iter()
-        .flat_map(|v| match v {
-            ColumnarValue::Scalar(s) => match s {
-                Decimal128(v128op, _, _) => *v128op,
-                datatype => unreachable!(
-                    "InList can't reach other data type {} for {}.",
-                    datatype, s
-                ),
-            },
-            ColumnarValue::Array(_) => {
-                unimplemented!("InList does not yet support nested columns.")
-            }
-        })
-        .collect::<Vec<_>>();
-
-    collection_contains_check_decimal!(array, values, negated, contains_null)
-}
-
-fn make_set_contains_decimal(
-    array: &Decimal128Array,
-    set: &HashSet<ScalarValue>,
-    negated: bool,
-) -> ColumnarValue {
-    let contains_null = set.iter().any(|v| v.is_null());
-    let native_set = set
-        .iter()
-        .flat_map(|v| match v {
-            Decimal128(v128op, _, _) => *v128op,
-            datatype => {
-                unreachable!("InList can't reach other data type {} for {}.", 
datatype, v)
-            }
-        })
-        .collect::<HashSet<_>>();
-
-    collection_contains_check_decimal!(array, native_set, negated, 
contains_null)
-}
-
-fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
-    array: &GenericStringArray<OffsetSize>,
-    set: &HashSet<ScalarValue>,
-    negated: bool,
-) -> ColumnarValue {
-    let contains_null = set.iter().any(|v| v.is_null());
-    let native_set = set
-        .iter()
-        .flat_map(|v| match v {
-            Utf8(v) | LargeUtf8(v) => v.as_deref(),
-            datatype => {
-                unreachable!("InList can't reach other data type {} for {}.", 
datatype, v)
-            }
+        .map(|expr| {
+            expr.evaluate(batch).and_then(|r| match r {
+                ColumnarValue::Array(_) => Err(DataFusionError::Execution(
+                    "InList expression must evaluate to a scalar".to_string(),

Review Comment:
   This is consistent with the code before, which instead returned "InList does 
not yet support nested columns". I think this error message is clearer about 
what the cause actually is; it has nothing to do with nested columns at all.
   
   See #3766 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to