seddonm1 commented on a change in pull request #9038:
URL: https://github.com/apache/arrow/pull/9038#discussion_r552950551



##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -3769,4 +4002,166 @@ mod tests {
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
         Ok(batch)
     }
+
+    // applies the in_list expr to an input batch and list
+    macro_rules! in_list {
+        ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{
+            let expr = in_list(col("a"), $LIST, $NEGATED).unwrap();
+            let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
+            let result = result
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .expect("failed to downcast to BooleanArray");
+            let expected = &BooleanArray::from($EXPECTED);
+            assert_eq!(expected, result);
+        }};
+    }
+
+    #[test]
+    fn in_list_utf8() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+        let a = StringArray::from(vec![Some("a"), Some("d"), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        // expression: "a in ("a", "b")"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+
+        // expression: "a not in ("a", "b")"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+
+        // expression: "a in ("a", "b", NULL)"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), None, None]);
+
+        // expression: "a not in ("a", "b", NULL)"
+        let list = vec![
+            lit(ScalarValue::Utf8(Some("a".to_string()))),
+            lit(ScalarValue::Utf8(Some("b".to_string()))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), None, None]);
+
+        Ok(())
+    }
+
+    #[test]
+    fn in_list_int64() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
+        let a = Int64Array::from(vec![Some(0), Some(2), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        // expression: "a in (0, 1)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+
+        // expression: "a not in (0, 1)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+
+        // expression: "a in (0, 1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), None, None]);
+
+        // expression: "a not in (0, 1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Int64(Some(0))),
+            lit(ScalarValue::Int64(Some(1))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), None, None]);
+
+        Ok(())
+    }
+
+    #[test]
+    fn in_list_float64() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+        let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
+        // expression: "a in (0.0, 0.1)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+
+        // expression: "a not in (0.0, 0.1)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+        ];
+        in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+
+        // expression: "a in (0.0, 0.1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+            lit(ScalarValue::Utf8(None)),
+        ];
+        in_list!(batch, list, &false, vec![Some(true), None, None]);
+
+        // expression: "a not in (0.0, 0.1, NULL)"
+        let list = vec![
+            lit(ScalarValue::Float64(Some(0.0))),
+            lit(ScalarValue::Float64(Some(0.1))),
+            lit(ScalarValue::Utf8(None)),

Review comment:
       The literal `Expr::Literal(ScalarValue::Utf8(None))` is currently a
special case in DataFusion that represents the SQL logical `NULL`. It is
passed through to evaluation because we need to know whether the list
contains a literal `NULL`, so that the return value can be overridden
with `NULL`. I think this could be optimised in future.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to