alamb commented on code in PR #4147:
URL: https://github.com/apache/arrow-datafusion/pull/4147#discussion_r1018281512


##########
datafusion/optimizer/src/unwrap_cast_in_comparison.rs:
##########
@@ -653,4 +655,196 @@ mod tests {
     fn null_decimal(precision: u8, scale: u8) -> Expr {
         lit(ScalarValue::Decimal128(None, precision, scale))
     }
+
+    #[test]
+    fn test_try_cast_to_type_nulls() {
+        // test values that can be cast to/from all integer types
+        let scalars = vec![
+            ScalarValue::Int8(None),
+            ScalarValue::Int16(None),
+            ScalarValue::Int32(None),
+            ScalarValue::Int64(None),
+            ScalarValue::Decimal128(None, 3, 0),
+            ScalarValue::Decimal128(None, 8, 2),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                expect_cast(
+                    s1.clone(),
+                    s2.get_datatype(),
+                    ExpectedCast::Value(s2.clone()),
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_try_cast_to_type_int_in_range() {
+        // test values that can be cast to/from all integer types
+        let scalars = vec![
+            ScalarValue::Int8(Some(123)),
+            ScalarValue::Int16(Some(123)),
+            ScalarValue::Int32(Some(123)),
+            ScalarValue::Int64(Some(123)),
+            ScalarValue::Decimal128(Some(123), 3, 0),
+            ScalarValue::Decimal128(Some(12300), 8, 2),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                expect_cast(
+                    s1.clone(),
+                    s2.get_datatype(),
+                    ExpectedCast::Value(s2.clone()),
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_try_cast_to_type_int_out_of_range() {
+        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
+        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
+        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
+
+        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
+
+        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
+
+        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
+
+        // decimal out of range
+        expect_cast(
+            
ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
+            DataType::Int64,
+            ExpectedCast::NoValue,
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 
37, 1),
+            DataType::Int64,
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[test]
+    fn test_try_decimal_cast_in_range() {
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(3, 0),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(8, 0),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(8, 5),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
+        );
+    }
+
+    #[test]
+    fn test_try_decimal_cast_out_of_range() {
+        // decimal would lose precision
+        expect_cast(
+            ScalarValue::Decimal128(Some(12345), 5, 2),
+            DataType::Decimal128(3, 0),
+            ExpectedCast::NoValue,
+        );
+
+        // decimal would lose precision
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(2, 0),
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[test]
+    fn test_try_cast_to_type_unsupported() {
+        // int64 to list
+        expect_cast(
+            ScalarValue::Int64(Some(12345)),
+            DataType::List(Box::new(Field::new("f", DataType::Int32, true))),
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[derive(Debug, Clone)]
+    enum ExpectedCast {
+        /// test successfully cast value and it is as specified
+        Value(ScalarValue),
+        /// test returned OK, but could not cast the value
+        NoValue,
+    }
+
+    /// Runs try_cast_literal_to_type with the specified inputs and
+    /// ensure it computes the expected output, and ensures the
+    /// casting is consistent with the Arrow kernels
+    fn expect_cast(
+        literal: ScalarValue,
+        target_type: DataType,
+        expected_result: ExpectedCast,
+    ) {
+        let actual_result = try_cast_literal_to_type(&literal, &target_type);
+
+        println!("expect_cast: ");
+        println!("  {:?} --> {:?}", literal, target_type);
+        println!("  expected_result: {:?}", expected_result);
+        println!("  actual_result:   {:?}", actual_result);
+
+        match expected_result {
+            ExpectedCast::Value(expected_value) => {
+                let actual_value = actual_result
+                    .expect("Expected success but got error")
+                    .expect("Expected cast value but got None");
+
+                assert_eq!(actual_value, expected_value);
+
+                // Verify that calling the arrow
+                // cast kernel yields the same results
+                // input array
+                let literal_array = literal.to_array_of_size(1);
+                let expected_array = expected_value.to_array_of_size(1);
+                let cast_array = cast_with_options(
+                    &literal_array,
+                    &target_type,
+                    &CastOptions { safe: true },
+                )
+                .expect("Expected to be cast array with arrow cast kernel");
+
+                assert_eq!(
+                    &expected_array, &cast_array,
+                    "Result of casing {:?} with arrow was\n {:#?}\nbut 
expected\n{:#?}",
+                    literal, cast_array, expected_array
+                );
+
+                // Verify that for timestamp types the timezones are the same
+                // (ScalarValue::cmp doesn't account for timezones);
+                if let (
+                    DataType::Timestamp(left_unit, left_tz),
+                    DataType::Timestamp(right_unit, right_tz),
+                ) = (actual_value.get_datatype(), 
expected_value.get_datatype())
+                {
+                    assert_eq!(left_unit, right_unit);
+                    assert_eq!(left_tz, right_tz);
+                }
+            }

Review Comment:
   > I know the timestamp is most important type in the time series database.
   > Optimization about the type of time will benefit the InfluxIO
   
   Yes, indeed this is why I care so much about timestamps :)



##########
datafusion/optimizer/src/unwrap_cast_in_comparison.rs:
##########
@@ -653,4 +655,196 @@ mod tests {
     fn null_decimal(precision: u8, scale: u8) -> Expr {
         lit(ScalarValue::Decimal128(None, precision, scale))
     }
+
+    #[test]
+    fn test_try_cast_to_type_nulls() {
+        // test values that can be cast to/from all integer types
+        let scalars = vec![
+            ScalarValue::Int8(None),
+            ScalarValue::Int16(None),
+            ScalarValue::Int32(None),
+            ScalarValue::Int64(None),
+            ScalarValue::Decimal128(None, 3, 0),
+            ScalarValue::Decimal128(None, 8, 2),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                expect_cast(
+                    s1.clone(),
+                    s2.get_datatype(),
+                    ExpectedCast::Value(s2.clone()),
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_try_cast_to_type_int_in_range() {
+        // test values that can be cast to/from all integer types
+        let scalars = vec![
+            ScalarValue::Int8(Some(123)),
+            ScalarValue::Int16(Some(123)),
+            ScalarValue::Int32(Some(123)),
+            ScalarValue::Int64(Some(123)),
+            ScalarValue::Decimal128(Some(123), 3, 0),
+            ScalarValue::Decimal128(Some(12300), 8, 2),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                expect_cast(
+                    s1.clone(),
+                    s2.get_datatype(),
+                    ExpectedCast::Value(s2.clone()),
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_try_cast_to_type_int_out_of_range() {
+        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
+        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
+        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
+
+        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
+
+        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
+
+        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
+
+        // decimal out of range
+        expect_cast(
+            
ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
+            DataType::Int64,
+            ExpectedCast::NoValue,
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 
37, 1),
+            DataType::Int64,
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[test]
+    fn test_try_decimal_cast_in_range() {
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(3, 0),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(8, 0),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(8, 5),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
+        );
+    }
+
+    #[test]
+    fn test_try_decimal_cast_out_of_range() {
+        // decimal would lose precision
+        expect_cast(
+            ScalarValue::Decimal128(Some(12345), 5, 2),
+            DataType::Decimal128(3, 0),
+            ExpectedCast::NoValue,
+        );
+
+        // decimal would lose precision
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(2, 0),
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[test]
+    fn test_try_cast_to_type_unsupported() {
+        // int64 to list
+        expect_cast(
+            ScalarValue::Int64(Some(12345)),
+            DataType::List(Box::new(Field::new("f", DataType::Int32, true))),
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[derive(Debug, Clone)]
+    enum ExpectedCast {
+        /// test successfully cast value and it is as specified
+        Value(ScalarValue),
+        /// test returned OK, but could not cast the value
+        NoValue,
+    }
+
+    /// Runs try_cast_literal_to_type with the specified inputs and
+    /// ensure it computes the expected output, and ensures the
+    /// casting is consistent with the Arrow kernels
+    fn expect_cast(
+        literal: ScalarValue,
+        target_type: DataType,
+        expected_result: ExpectedCast,
+    ) {
+        let actual_result = try_cast_literal_to_type(&literal, &target_type);
+
+        println!("expect_cast: ");
+        println!("  {:?} --> {:?}", literal, target_type);
+        println!("  expected_result: {:?}", expected_result);
+        println!("  actual_result:   {:?}", actual_result);
+
+        match expected_result {
+            ExpectedCast::Value(expected_value) => {
+                let actual_value = actual_result
+                    .expect("Expected success but got error")
+                    .expect("Expected cast value but got None");
+
+                assert_eq!(actual_value, expected_value);
+
+                // Verify that calling the arrow
+                // cast kernel yields the same results
+                // input array
+                let literal_array = literal.to_array_of_size(1);
+                let expected_array = expected_value.to_array_of_size(1);
+                let cast_array = cast_with_options(

Review Comment:
   I am glad I had it too 😅, as I rediscovered that the Arrow cast kernel
   doesn't support decimal <--> unsigned casts, so my initial implementation
   of unsigned support would have been inconsistent with it.



##########
datafusion/optimizer/src/unwrap_cast_in_comparison.rs:
##########
@@ -653,4 +655,196 @@ mod tests {
     fn null_decimal(precision: u8, scale: u8) -> Expr {
         lit(ScalarValue::Decimal128(None, precision, scale))
     }
+
+    #[test]
+    fn test_try_cast_to_type_nulls() {
+        // test values that can be cast to/from all integer types
+        let scalars = vec![
+            ScalarValue::Int8(None),
+            ScalarValue::Int16(None),
+            ScalarValue::Int32(None),
+            ScalarValue::Int64(None),
+            ScalarValue::Decimal128(None, 3, 0),
+            ScalarValue::Decimal128(None, 8, 2),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                expect_cast(
+                    s1.clone(),
+                    s2.get_datatype(),
+                    ExpectedCast::Value(s2.clone()),
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_try_cast_to_type_int_in_range() {
+        // test values that can be cast to/from all integer types
+        let scalars = vec![
+            ScalarValue::Int8(Some(123)),
+            ScalarValue::Int16(Some(123)),
+            ScalarValue::Int32(Some(123)),
+            ScalarValue::Int64(Some(123)),
+            ScalarValue::Decimal128(Some(123), 3, 0),
+            ScalarValue::Decimal128(Some(12300), 8, 2),
+        ];
+
+        for s1 in &scalars {
+            for s2 in &scalars {
+                expect_cast(
+                    s1.clone(),
+                    s2.get_datatype(),
+                    ExpectedCast::Value(s2.clone()),
+                );
+            }
+        }
+    }
+
+    #[test]
+    fn test_try_cast_to_type_int_out_of_range() {
+        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
+        let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
+        expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
+
+        expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
+
+        expect_cast(max_i64, DataType::Int32, ExpectedCast::NoValue);
+
+        expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
+
+        // decimal out of range
+        expect_cast(
+            
ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
+            DataType::Int64,
+            ExpectedCast::NoValue,
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(-9999999999999999999999999999999999), 
37, 1),
+            DataType::Int64,
+            ExpectedCast::NoValue,
+        );
+    }
+
+    #[test]
+    fn test_try_decimal_cast_in_range() {
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(3, 0),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 3, 0)),
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(8, 0),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(123), 8, 0)),
+        );
+
+        expect_cast(
+            ScalarValue::Decimal128(Some(12300), 5, 2),
+            DataType::Decimal128(8, 5),
+            ExpectedCast::Value(ScalarValue::Decimal128(Some(12300000), 8, 5)),
+        );
+    }
+
+    #[test]
+    fn test_try_decimal_cast_out_of_range() {

Review Comment:
   I agree you have covered most (if not all) of these cases already in
   `test_not_unwrap_cast_with_decimal_comparison`.
   
   
   The reason I wanted to add redundant coverage was that, while writing tests
   for unwrapping timestamp casts, I also wanted to verify that they were
   consistent with the cast kernel, which I couldn't figure out how to do easily.
   
   Also, the tests as written tested both `IN` list cast removal and
   comparison cast removal, but the code that handles those structures is the
   same. Thus adding tests for unwrapping timestamps in both of those forms
   seemed like it didn't add extra coverage and was overly verbose, and it got
   even worse as I contemplated testing all the other types.
   
   So, since the only code that differs between data types is
   `try_cast_literal_to_type`, I figured I would write a test for that function
   to get enough coverage.
   
   



##########
datafusion/optimizer/src/unwrap_cast_in_comparison.rs:
##########
@@ -653,4 +655,196 @@ mod tests {
     fn null_decimal(precision: u8, scale: u8) -> Expr {
         lit(ScalarValue::Decimal128(None, precision, scale))
     }
+
+    #[test]
+    fn test_try_cast_to_type_nulls() {
+        // test values that can be cast to/from all integer types

Review Comment:
   This is a copy/pasted comment 🤦 -- I will fix it in the next PR. Sorry for
   the confusion, and nice spot!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to