This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e6a320bf4 [CALCITE-6265] Type coercion is failing for numeric values in prepared statements
4e6a320bf4 is described below

commit 4e6a320bf476c4b3c313d86f8ff3c6edf4c2a578
Author: Tim Nieradzik <[email protected]>
AuthorDate: Wed Feb 14 18:13:03 2024 +0200

    [CALCITE-6265] Type coercion is failing for numeric values in prepared statements
    
    Given a column of type `INT`, a `ClassCastException` is thrown when a
    `short` value is provided for a placeholder in a prepared statement.
    
    Test case:
    ```
    final String sql =
        "select \"empid\" from \"hr\".\"emps\" where \"empid\" in (?, ?)";
    CalciteAssert.hr()
        .query(sql)
        .consumesPreparedStatement(p -> {
            p.setShort(1, (short) 100);
            p.setShort(2, (short) 110);
        })
        .returnsUnordered("empid=100", "empid=110");
    ```
    
    Stack trace:
    ```
    java.lang.ClassCastException: class java.lang.Short cannot be cast to class java.lang.Integer (java.lang.Short and java.lang.Integer are in module java.base of loader 'bootstrap')
         at Baz$1$1.moveNext(Unknown Source)
         at org.apache.calcite.linq4j.Linq4j$EnumeratorIterator.<init>(Linq4j.java:679)
    ```
---
 .../adapter/enumerable/RexToLixTranslator.java     | 20 ++++-
 .../java/org/apache/calcite/test/JdbcTest.java     | 93 ++++++++++++++++++++++
 .../apache/calcite/linq4j/tree/Expressions.java    | 56 ++++++++++++-
 .../java/org/apache/calcite/linq4j/tree/Types.java | 25 ++++--
 4 files changed, 185 insertions(+), 9 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index 5389fdccf7..4ea1d23701 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -31,6 +31,7 @@ import org.apache.calcite.linq4j.tree.Expressions;
 import org.apache.calcite.linq4j.tree.ParameterExpression;
 import org.apache.calcite.linq4j.tree.Primitive;
 import org.apache.calcite.linq4j.tree.Statement;
+import org.apache.calcite.linq4j.tree.Types;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexCall;
@@ -1374,11 +1375,26 @@ public class RexToLixTranslator implements RexVisitor<RexToLixTranslator.Result>
     }
     final Type storageType = currentStorageType != null
         ? currentStorageType : typeFactory.getJavaClass(dynamicParam.getType());
-    final Expression valueExpression =
+
+    final boolean isNumeric = SqlTypeFamily.NUMERIC.contains(dynamicParam.getType());
+
+    // For numeric types, read the value as a Number and convert it below: the
+    // bound value's class (e.g. Short) may differ from storageType (Integer).
+    final Expression argumentExpression =
         EnumUtils.convert(
             Expressions.call(root, BuiltInMethod.DATA_CONTEXT_GET.method,
                 Expressions.constant("?" + dynamicParam.getIndex())),
-            storageType);
+            isNumeric ? Number.class : storageType);
+
+    // Short-circuit if the expression evaluates to null. The cast
+    // may throw a NullPointerException as it calls methods on the
+    // object such as longValue().
+    final Expression valueExpression =
+        Expressions.condition(
+            Expressions.equal(argumentExpression, Expressions.constant(null)),
+            Expressions.constant(null),
+            Types.castIfNecessary(storageType, argumentExpression));
+
     final ParameterExpression valueVariable =
         Expressions.parameter(valueExpression.getType(),
             list.newName("value_dynamic_param"));
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index 087b3854d1..8fe3de21ac 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -84,6 +84,7 @@ import org.apache.calcite.sql.SqlSpecialOperator;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.parser.impl.SqlParserImpl;
+import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql2rel.SqlToRelConverter.Config;
 import org.apache.calcite.test.schemata.catchall.CatchallSchema;
 import org.apache.calcite.test.schemata.foodmart.FoodmartSchema;
@@ -8423,6 +8424,98 @@ public class JdbcTest {
         });
   }
 
+  @Test void bindByteParameter() {
+    for (SqlTypeName typeName : SqlTypeName.INT_TYPES) {
+      final String sql =
+          "with cte as (select cast(100 as " + typeName.getName() + ") as empid)"
+              + " select * from cte where empid = ?";
+      CalciteAssert.hr()
+          .query(sql)
+          .consumesPreparedStatement(p -> {
+            p.setByte(1, (byte) 100);
+          })
+          .returnsUnordered("EMPID=100");
+    }
+  }
+
+  @Test void bindShortParameter() {
+    for (SqlTypeName typeName : SqlTypeName.INT_TYPES) {
+      final String sql =
+          "with cte as (select cast(100 as " + typeName.getName() + ") as empid)"
+              + " select * from cte where empid = ?";
+
+      CalciteAssert.hr()
+          .query(sql)
+          .consumesPreparedStatement(p -> {
+            p.setShort(1, (short) 100);
+          })
+          .returnsUnordered("EMPID=100");
+    }
+  }
+
+  @Test void bindOverflowingTinyIntParameter() {
+    final String sql =
+        "with cte as (select cast(300 as smallint) as empid)"
+            + " select * from cte where empid = cast(? as tinyint)";
+
+    final java.sql.SQLException e =
+        assertThrows(
+            java.sql.SQLException.class,
+            () -> CalciteAssert.hr()
+                .query(sql)
+                .consumesPreparedStatement(p -> {
+                  p.setShort(1, (short) 300);
+                })
+                .returns(""));
+
+    assertThat(
+        "message matches",
+        e.getMessage().contains("value is outside the range of java.lang.Byte"));
+  }
+
+  @Test void bindIntParameter() {
+    for (SqlTypeName typeName : SqlTypeName.INT_TYPES) {
+      final String sql =
+          "with cte as (select cast(100 as " + typeName.getName() + ") as empid)"
+              + " select * from cte where empid = ?";
+
+      CalciteAssert.hr()
+          .query(sql)
+          .consumesPreparedStatement(p -> {
+            p.setInt(1, 100);
+          })
+          .returnsUnordered("EMPID=100");
+    }
+  }
+
+  @Test void bindLongParameter() {
+    for (SqlTypeName typeName : SqlTypeName.INT_TYPES) {
+      final String sql =
+          "with cte as (select cast(100 as " + typeName.getName() + ") as empid)"
+              + " select * from cte where empid = ?";
+
+      CalciteAssert.hr()
+          .query(sql)
+          .consumesPreparedStatement(p -> {
+            p.setLong(1, 100);
+          })
+          .returnsUnordered("EMPID=100");
+    }
+  }
+
+  @Test void bindNumericParameter() {
+    final String sql =
+        "with cte as (select cast(100 as numeric(5)) as empid)"
+            + " select * from cte where empid = ?";
+
+    CalciteAssert.hr()
+        .query(sql)
+        .consumesPreparedStatement(p -> {
+          p.setLong(1, 100);
+        })
+        .returnsUnordered("EMPID=100");
+  }
+
   private static String sums(int n, boolean c) {
     final StringBuilder b = new StringBuilder();
     for (int i = 0; i < n; i++) {
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java
index 2e49cbcb05..a7f7545b8e 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java
@@ -622,9 +622,66 @@ public abstract class Expressions {
    * operation that throws an exception if the target type is
    * overflowed.
+   *
+   * <p>Only {@code Byte}, {@code Short}, {@code Integer} and {@code Long}
+   * target types are supported. The generated code throws
+   * {@link IllegalArgumentException} if the value is outside the range of
+   * the target type; any fractional part is discarded.
+   *
+   * @param expression Numeric expression to convert; the generated code
+   *     references it more than once, so it should be free of side effects
+   * @param type Target type
+   * @throws IllegalArgumentException if {@code type} is not supported
    */
   public static UnaryExpression convertChecked(Expression expression,
       Type type) {
-    throw Extensions.todo();
+    final long minValue;
+    final long maxValue;
+    if (type == Byte.class) {
+      minValue = Byte.MIN_VALUE;
+      maxValue = Byte.MAX_VALUE;
+    } else if (type == Short.class) {
+      minValue = Short.MIN_VALUE;
+      maxValue = Short.MAX_VALUE;
+    } else if (type == Integer.class) {
+      minValue = Integer.MIN_VALUE;
+      maxValue = Integer.MAX_VALUE;
+    } else if (type == Long.class) {
+      minValue = Long.MIN_VALUE;
+      maxValue = Long.MAX_VALUE;
+    } else {
+      throw new IllegalArgumentException("Type " + type.getTypeName()
+          + " is not supported yet");
+    }
+    final Class<?> typeClass = (Class<?>) type;
+
+    final ThrowStatement throwStmt =
+        Expressions.throw_(
+            Expressions.new_(
+                IllegalArgumentException.class,
+                Expressions.constant("value is outside the range of "
+                    + typeClass.getName())));
+
+    // Compare as long, which is wide enough for every supported target type.
+    final Expression longValue = Expressions.call(expression, "longValue");
+    final Expression outOfRange =
+        Expressions.orElse(
+            Expressions.lessThan(longValue,
+                Expressions.constant(minValue, long.class)),
+            Expressions.greaterThan(longValue,
+                Expressions.constant(maxValue, long.class)));
+
+    final Primitive primitive = requireNonNull(Primitive.ofBox(type));
+    final String primitiveName = requireNonNull(primitive.primitiveName);
+    final Expression convertExpr =
+        Expressions.call(expression, primitiveName + "Value");
+
+    return Expressions.convert_(
+        Expressions.makeTernary(
+            ExpressionType.Conditional,
+            outOfRange,
+            Expressions.fromStatement(throwStmt),
+            convertExpr),
+        type);
   }
 
   /**
@@ -2822,6 +2879,26 @@ public abstract class Expressions {
     throw Extensions.todo();
   }
 
+  /**
+   * Creates an expression that executes a statement, by wrapping the
+   * statement in a zero-argument lambda and invoking it immediately.
+   *
+   * <p>Useful where an expression is required but the logic needs a
+   * statement; for example, to throw an exception from one branch of a
+   * conditional expression.
+   *
+   * @param statement Statement to execute
+   * @return Expression that calls {@code apply} on the generated lambda
+   */
+  public static Expression fromStatement(Statement statement) {
+    final FunctionExpression<Function<?>> lambda =
+        Expressions.lambda(
+            Blocks.toFunctionBlock(statement),
+            Collections.emptyList());
+
+    return Expressions.call(lambda, "apply");
+  }
+
   /**
    * Creates a statement that represents the throwing of an exception.
    */
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
index 0be03c9fc0..3dec960cf7 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
@@ -28,6 +28,7 @@ import java.lang.reflect.Method;
 import java.lang.reflect.ParameterizedType;
 import java.lang.reflect.Type;
 import java.lang.reflect.TypeVariable;
+import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -428,11 +429,31 @@ public abstract class Types {
         && Number.class.isAssignableFrom((Class) returnType)
         && type instanceof Class
         && Number.class.isAssignableFrom((Class) type)) {
-      // E.g.
-      //   Integer foo(BigDecimal o) {
-      //     return o.intValue();
-      //   }
-      return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));
+
+      if (returnType == BigDecimal.class) {
+        // E.g.
+        //   BigDecimal foo(Number o) {
+        //     return new BigDecimal(o.toString());
+        //   }
+        // Converting via the decimal string keeps the fractional digits of
+        // Float, Double and BigDecimal values; longValue() would drop them.
+        return Expressions.new_(
+            BigDecimal.class,
+            Expressions.call(expression, "toString"));
+      } else if (
+          returnType == Byte.class
+              || returnType == Short.class
+              || returnType == Integer.class
+              || returnType == Long.class) {
+        // Narrowing with a range check; the generated code throws if out of range.
+        return Expressions.convertChecked(expression, returnType);
+      } else {
+        // E.g.
+        //   Integer foo(BigDecimal o) {
+        //     return o.intValue();
+        //   }
+        return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));
+      }
     }
     if (Primitive.is(returnType) && !Primitive.is(type)) {
       // E.g.

Reply via email to