This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new e2c84a6485 [CALCITE-6111] Explicit cast from expression to numeric type doesn't check overflow
e2c84a6485 is described below

commit e2c84a6485afce05bd0dcf5a9d6e9aafcb8af65a
Author: Mihai Budiu <[email protected]>
AuthorDate: Fri Jan 19 14:20:07 2024 -0800

    [CALCITE-6111] Explicit cast from expression to numeric type doesn't check overflow
    
    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../adapter/enumerable/RexToLixTranslator.java     | 14 +++++
 .../org/apache/calcite/util/BuiltInMethod.java     |  1 +
 .../org/apache/calcite/linq4j/tree/Primitive.java  | 11 ++++
 .../calcite/sql/test/SqlOperatorFixture.java       | 12 +++-
 .../org/apache/calcite/test/SqlOperatorTest.java   | 70 +++++++++++++++++-----
 5 files changed, 90 insertions(+), 18 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index 8f53929cb7..5389fdccf7 100644
--- 
a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++ 
b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -59,6 +59,7 @@ import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.SqlWindowTableFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.type.SqlTypeFamily;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.sql.validate.SqlConformance;
 import org.apache.calcite.util.BuiltInMethod;
@@ -534,6 +535,19 @@ public class RexToLixTranslator implements 
RexVisitor<RexToLixTranslator.Result>
         return defaultExpression.get();
       }
 
+    case BIGINT:
+    case INTEGER:
+    case TINYINT:
+    case SMALLINT: {
+      if (SqlTypeName.NUMERIC_TYPES.contains(sourceType.getSqlTypeName())) {
+        return Expressions.call(
+            BuiltInMethod.INTEGER_CAST.method,
+            
Expressions.constant(Primitive.of(typeFactory.getJavaClass(targetType))),
+            operand);
+      }
+      return defaultExpression.get();
+    }
+
     default:
       return defaultExpression.get();
     }
diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java 
b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
index 142f1d30d2..3f56f3f321 100644
--- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
+++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java
@@ -291,6 +291,7 @@ public enum BuiltInMethod {
   ENUMERABLE_TO_LIST(ExtendedEnumerable.class, "toList"),
   ENUMERABLE_TO_MAP(ExtendedEnumerable.class, "toMap", Function1.class, 
Function1.class),
   AS_LIST(Primitive.class, "asList", Object.class),
+  INTEGER_CAST(Primitive.class, "integerCast", Primitive.class, Object.class),
   MEMORY_GET0(MemoryFactory.Memory.class, "get"),
   MEMORY_GET1(MemoryFactory.Memory.class, "get", int.class),
   ENUMERATOR_CURRENT(Enumerator.class, "current"),
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java 
b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java
index b3a43f1ae2..6aa575c795 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java
@@ -384,6 +384,10 @@ public enum Primitive {
     }
   }
 
+  public static @Nullable Object integerCast(Primitive primitive, final Object 
value) {
+    return requireNonNull(primitive, "primitive").numberValue((Number) value);
+  }
+
   /**
    * Converts a number into a value of the type specified by this primitive
    * using the SQL CAST rules.  If the value conversion causes loss of 
significant digits,
@@ -424,6 +428,13 @@ public enum Primitive {
         // longValueExact will throw ArithmeticException if out of range
         return decimal.longValueExact();
       }
+      if (value instanceof BigDecimal) {
+        BigDecimal decimal = ((BigDecimal) value)
+            // Truncate toward zero to an integer (drop the fractional part)
+            .setScale(0, RoundingMode.DOWN);
+        // longValueExact will throw ArithmeticException if out of range
+        return decimal.longValueExact();
+      }
       throw new AssertionError("Unexpected Number type "
           + value.getClass().getSimpleName());
     case FLOAT:
diff --git 
a/testkit/src/main/java/org/apache/calcite/sql/test/SqlOperatorFixture.java 
b/testkit/src/main/java/org/apache/calcite/sql/test/SqlOperatorFixture.java
index 7bcfcb1cb0..f87208c02a 100644
--- a/testkit/src/main/java/org/apache/calcite/sql/test/SqlOperatorFixture.java
+++ b/testkit/src/main/java/org/apache/calcite/sql/test/SqlOperatorFixture.java
@@ -71,7 +71,9 @@ public interface SqlOperatorFixture extends AutoCloseable {
   // TODO: Change message
   String INVALID_CHAR_MESSAGE = "(?s).*";
 
-  String OUT_OF_RANGE_MESSAGE = ".* out of range";
+  String OUT_OF_RANGE_MESSAGE = ".* out of range.*";
+
+  String WRONG_FORMAT_MESSAGE = "Number has wrong format.*";
 
   // TODO: Change message
   String DIVISION_BY_ZERO_MESSAGE = "(?s).*";
@@ -643,8 +645,12 @@ public interface SqlOperatorFixture extends AutoCloseable {
 
   default void checkCastFails(String value, String targetType,
       String expectedError, boolean runtime, CastType castType) {
-    final String castString = getCastString(value, targetType, !runtime, 
castType);
-    checkFails(castString, expectedError, runtime);
+    final String query = getCastString(value, targetType, !runtime, castType);
+    if (castType == CastType.CAST || !runtime) {
+      checkFails(query, expectedError, runtime);
+    } else {
+      checkNull(query);
+    }
   }
 
   default void checkCastToString(String value, @Nullable String type,
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java 
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 2eb931d85a..dabb16e752 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -133,6 +133,7 @@ import static 
org.apache.calcite.sql.test.SqlOperatorFixture.INVALID_EXTRACT_UNI
 import static 
org.apache.calcite.sql.test.SqlOperatorFixture.INVALID_EXTRACT_UNIT_VALIDATION_ERROR;
 import static 
org.apache.calcite.sql.test.SqlOperatorFixture.LITERAL_OUT_OF_RANGE_MESSAGE;
 import static 
org.apache.calcite.sql.test.SqlOperatorFixture.OUT_OF_RANGE_MESSAGE;
+import static 
org.apache.calcite.sql.test.SqlOperatorFixture.WRONG_FORMAT_MESSAGE;
 import static org.apache.calcite.util.DateTimeStringUtils.getDateFormatter;
 
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -623,13 +624,16 @@ public class SqlOperatorTest {
 
       // Overflow test
       if (numeric == Numeric.BIGINT) {
-        // Literal of range
+        // Calcite cannot even represent a literal so large, so
+        // for this query even the safe casts fail at compile-time
+        // (runtime == false).
         f.checkCastFails(numeric.maxOverflowNumericString,
             type, LITERAL_OUT_OF_RANGE_MESSAGE, false, castType);
         f.checkCastFails(numeric.minOverflowNumericString,
             type, LITERAL_OUT_OF_RANGE_MESSAGE, false, castType);
       } else {
-        if (numeric != Numeric.DECIMAL5_2 || Bug.CALCITE_2539_FIXED) {
+        if (numeric != Numeric.DECIMAL5_2) {
+          // This condition is for bug [CALCITE-6078], not yet fixed
           f.checkCastFails(numeric.maxOverflowNumericString,
               type, OUT_OF_RANGE_MESSAGE, true, castType);
           f.checkCastFails(numeric.minOverflowNumericString,
@@ -643,11 +647,12 @@ public class SqlOperatorTest {
       f.checkCastToScalarOkay("'" + numeric.minNumericString + "'",
           type, numeric.minNumericString, castType);
 
-      if (Bug.CALCITE_2539_FIXED) {
+      if (numeric != Numeric.DECIMAL5_2) {
+        // The above condition is for bug CALCITE-6078
         f.checkCastFails("'" + numeric.maxOverflowNumericString + "'",
-            type, OUT_OF_RANGE_MESSAGE, true, castType);
+            type, WRONG_FORMAT_MESSAGE, true, castType);
         f.checkCastFails("'" + numeric.minOverflowNumericString + "'",
-            type, OUT_OF_RANGE_MESSAGE, true, castType);
+            type, WRONG_FORMAT_MESSAGE, true, castType);
       }
 
       // Convert from type to string
@@ -657,10 +662,8 @@ public class SqlOperatorTest {
       f.checkCastToString(numeric.minNumericString, null, null, castType);
       f.checkCastToString(numeric.minNumericString, type, null, castType);
 
-      if (Bug.CALCITE_2539_FIXED) {
-        f.checkCastFails("'notnumeric'", type, INVALID_CHAR_MESSAGE, true,
-            castType);
-      }
+      f.checkCastFails("'notnumeric'", type, INVALID_CHAR_MESSAGE, true,
+          castType);
     });
   }
 
@@ -1128,14 +1131,14 @@ public class SqlOperatorTest {
     // ExceptionInInitializerError.
     f.checkScalarExact("cast('15' as integer)", "INTEGER NOT NULL", "15");
     if (castType == CastType.CAST) { // Safe casts should not fail
-      f.checkFails("cast('15.4' as integer)", "Number has wrong format.*", 
true);
-      f.checkFails("cast('15.6' as integer)", "Number has wrong format.*", 
true);
+      f.checkFails("cast('15.4' as integer)", WRONG_FORMAT_MESSAGE, true);
+      f.checkFails("cast('15.6' as integer)", WRONG_FORMAT_MESSAGE, true);
       f.checkFails("cast('ue' as boolean)", "Invalid character for cast.*", 
true);
       f.checkFails("cast('' as boolean)", "Invalid character for cast.*", 
true);
-      f.checkFails("cast('' as integer)", "Number has wrong format.*", true);
-      f.checkFails("cast('' as real)", "Number has wrong format.*", true);
-      f.checkFails("cast('' as double)", "Number has wrong format.*", true);
-      f.checkFails("cast('' as smallint)", "Number has wrong format.*", true);
+      f.checkFails("cast('' as integer)", WRONG_FORMAT_MESSAGE, true);
+      f.checkFails("cast('' as real)", WRONG_FORMAT_MESSAGE, true);
+      f.checkFails("cast('' as double)", WRONG_FORMAT_MESSAGE, true);
+      f.checkFails("cast('' as smallint)", WRONG_FORMAT_MESSAGE, true);
     } else {
       f.checkNull("cast('15.4' as integer)");
       f.checkNull("cast('15.6' as integer)");
@@ -13695,6 +13698,43 @@ public class SqlOperatorTest {
     }
   }
 
+  /**
+   * Test cases for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6111">[CALCITE-6111]
+   * Explicit cast from expression to numeric type doesn't check overflow</a>. */
+  @Test public void testOverflow() {
+    final SqlOperatorFixture f = fixture();
+    f.checkFails(String.format(Locale.US, "SELECT cast(%d+30 as tinyint)", 
Byte.MAX_VALUE),
+        OUT_OF_RANGE_MESSAGE, true);
+    f.checkFails(String.format(Locale.US, "SELECT cast(%d+30 as smallint)", 
Short.MAX_VALUE),
+        OUT_OF_RANGE_MESSAGE, true);
+    // We use a long value because otherwise Calcite interprets the literal as an INTEGER.
+    f.checkFails(String.format(Locale.US, "SELECT cast(%d as int)", 
Long.MAX_VALUE),
+        OUT_OF_RANGE_MESSAGE, true);
+
+    // Casting a floating point value larger than the maximum allowed value.
+    // 1e60 is larger than the largest BIGINT value allowed.
+    f.checkFails("SELECT cast(1e60+30 as tinyint)",
+        OUT_OF_RANGE_MESSAGE, true);
+    f.checkFails("SELECT cast(1e60+30 as smallint)",
+        OUT_OF_RANGE_MESSAGE, true);
+    f.checkFails("SELECT cast(1e60+30 as int)",
+        OUT_OF_RANGE_MESSAGE, true);
+    f.checkFails("SELECT cast(1e60+30 as bigint)",
+        ".*Overflow", true);
+
+    // Casting a decimal value larger than the maximum allowed value.
+    // Concatenating .0 to a value makes it decimal.
+    f.checkFails(String.format(Locale.US, "SELECT cast(%d.0 AS tinyint)", 
Short.MAX_VALUE),
+        OUT_OF_RANGE_MESSAGE, true);
+    f.checkFails(String.format(Locale.US, "SELECT cast(%d.0 AS smallint)", 
Integer.MAX_VALUE),
+        OUT_OF_RANGE_MESSAGE, true);
+    // Dividing Long.MAX_VALUE by 10 ensures that the resulting decimal does not exceed the
+    // maximum allowed precision for decimals but is still too large for an integer.
+    f.checkFails(String.format(Locale.US, "SELECT cast(%d.0 AS int)", 
Long.MAX_VALUE / 10),
+        OUT_OF_RANGE_MESSAGE, true);
+  }
+
   @ParameterizedTest
   @MethodSource("safeParameters")
   void testCastTruncates(CastType castType, SqlOperatorFixture f) {

Reply via email to