jorgecarleitao commented on a change in pull request #8688: URL: https://github.com/apache/arrow/pull/8688#discussion_r524889515
########## File path: rust/arrow/src/compute/kernels/boolean.rs ########## @@ -149,6 +150,64 @@ pub fn is_not_null(input: &ArrayRef) -> Result<BooleanArray> { Ok(BooleanArray::from(Arc::new(data))) } +/// Copies original array, setting null bit to true if a secondary comparison boolean array is set to true. +/// Typically used to implement NULLIF. +pub fn nullif<T>( + left: &PrimitiveArray<T>, Review comment: If I understood correctly, because we only use `left.data()`, if we replace `PrimitiveArray<T>` with `impl Array` and remove the generic `T`, this generalizes to all data types. :) Note that this will not address the issue in DataFusion, as it seems that we can only compute comparison boolean arrays from primitive arrays at the moment. ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -2604,6 +2678,61 @@ mod tests { ) } + #[test] + #[rustfmt::skip] + fn nullif_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None, None, Some(4), Some(5)]); + let a = Arc::new(a); + let a_len = a.len(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone()])?; + + let literal_expr = lit(ScalarValue::Int32(Some(2))); + let lit_array = literal_expr.evaluate(&batch)?; + + let result = nullif_func(&[a.clone(), lit_array])?; + + // Results should be: Some(1), None, None, None, Some(3), None + assert_eq!(result.len(), a_len); + + let result = result + .as_any() + .downcast_ref::<Int32Array>() + .expect("failed to downcast to Int32Array"); + + assert_eq!(1, result.value(0)); + assert_eq!(true, result.is_null(1)); // 2==2, slot 1 turned into null + assert_eq!(true, result.is_null(2)); + assert_eq!(true, result.is_null(3)); + assert_eq!(3, result.value(4)); + assert_eq!(true, result.is_null(5)); + assert_eq!(true, result.is_null(6)); + assert_eq!(5, result.value(8)); + Ok(()) Review comment: Same here: it is easier to debug if there is a 
single `assert_eq`, as we can see the whole array in the message. ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -2604,6 +2678,61 @@ mod tests { ) } + #[test] + #[rustfmt::skip] Review comment: why `skip`? ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -2604,6 +2678,61 @@ mod tests { ) } + #[test] + #[rustfmt::skip] + fn nullif_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None, None, Some(4), Some(5)]); + let a = Arc::new(a); + let a_len = a.len(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone()])?; + + let literal_expr = lit(ScalarValue::Int32(Some(2))); + let lit_array = literal_expr.evaluate(&batch)?; + + let result = nullif_func(&[a.clone(), lit_array])?; + + // Results should be: Some(1), None, None, None, Some(3), None + assert_eq!(result.len(), a_len); + + let result = result + .as_any() + .downcast_ref::<Int32Array>() + .expect("failed to downcast to Int32Array"); + + assert_eq!(1, result.value(0)); + assert_eq!(true, result.is_null(1)); // 2==2, slot 1 turned into null + assert_eq!(true, result.is_null(2)); + assert_eq!(true, result.is_null(3)); + assert_eq!(3, result.value(4)); + assert_eq!(true, result.is_null(5)); + assert_eq!(true, result.is_null(6)); + assert_eq!(5, result.value(8)); + Ok(()) + } + + #[test] + #[rustfmt::skip] + // Ensure that arrays with no nulls can also invoke NULLIF() correctly + fn nullif_int32_nonulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); + let a = Arc::new(a); + let a_len = a.len(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone()])?; + + let literal_expr = lit(ScalarValue::Int32(Some(1))); + let lit_array = literal_expr.evaluate(&batch)?; Review comment: I 
think that this can be simplified to `let lit_array = Int32Array::from(vec![1; a.len()]);`. ########## File path: rust/datafusion/src/physical_plan/expressions.rs ########## @@ -2604,6 +2678,61 @@ mod tests { ) } + #[test] + #[rustfmt::skip] + fn nullif_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None, None, Some(4), Some(5)]); + let a = Arc::new(a); + let a_len = a.len(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a.clone()])?; + + let literal_expr = lit(ScalarValue::Int32(Some(2))); + let lit_array = literal_expr.evaluate(&batch)?; + + let result = nullif_func(&[a.clone(), lit_array])?; + + // Results should be: Some(1), None, None, None, Some(3), None Review comment: This comment suggests building the full expected array instead, e.g. `let expected = Int32Array::from(vec![Some(1), None, None, None, Some(3), None, None, Some(4), Some(5)]);`, and comparing it against `result` with a single `assert_eq!`. ########## File path: rust/datafusion/tests/sql.rs ########## @@ -508,6 +508,26 @@ async fn csv_query_avg_multi_batch() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_nullif_divide_by_0() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; + let actual: Vec<_> = execute(&mut ctx, sql) + .await + .iter() + .map(|x| x[0].clone()) + .collect(); + let actual = actual.join("\n"); + let expected = "1722\n92\n46\n679\n165\n146\n149\n93\n2211\n6495\n307\n139\n253\n123\n21\n84\n98\n13\n230\n\ + 277\n1\n986\n414\n144\n210\n0\n172\n165\n25\n97\n335\n558\n350\n369\n511\n245\n345\n8\n139\n55\n318\n2614\n\ + 1792\n16\n345\n123\n176\n1171\n20\n199\n147\n115\n335\n23\n847\n94\n315\n391\n176\n282\n459\n197\n978\n281\n\ + 27\n26\n281\n8124\n3\n430\n510\n61\n67\n17\n1601\n362\n202\n50\n10\n346\n258\n664\nNULL\n22\n164\n448\n365\n\ + 1640\n671\n203\n2087\n10060\n1015\n913\n9840\n16\n496\n264\n38\n1"; + assert_eq!(expected, actual); + Ok(()) +} Review comment: 
What do you think about running against a controlled dataset instead of `aggregate_test_100`? I can't tell whether `expected` is correct. Wouldn't it be easier to do what the `query_not` test does, i.e. create a small temporary table and run the query against that? ########## File path: rust/arrow/src/compute/kernels/boolean.rs ########## @@ -457,4 +516,20 @@ mod tests { assert_eq!(true, res.value(2)); assert_eq!(false, res.value(3)); } + + #[test] + fn test_nullif_int_array() { + let a = Int32Array::from(vec![Some(15), None, Some(8), Some(1), Some(9)]); + let comp = + BooleanArray::from(vec![Some(false), None, Some(true), Some(false), None]); + let res = nullif(&a, &comp).unwrap(); + + assert_eq!(15, res.value(0)); + assert_eq!(true, res.is_null(1)); + assert_eq!(true, res.is_null(2)); // comp true, slot 2 turned into null + assert_eq!(1, res.value(3)); + // Even though comp array / right is null, should still pass through original value + assert_eq!(9, res.value(4)); + assert_eq!(false, res.is_null(4)); // comp true, slot 2 turned into null Review comment: ```suggestion let expected = Int32Array::from(vec![ Some(15), None, None, // comp true, slot 2 turned into null Some(1), // Even though comp array / right is null, should still pass through original value Some(9), ]); assert_eq!(expected, res) ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org