This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new 4e6a320bf4 [CALCITE-6265] Type coercion is failing for numeric values
in prepared statements
4e6a320bf4 is described below
commit 4e6a320bf476c4b3c313d86f8ff3c6edf4c2a578
Author: Tim Nieradzik <[email protected]>
AuthorDate: Wed Feb 14 18:13:03 2024 +0200
[CALCITE-6265] Type coercion is failing for numeric values in prepared
statements
Given a column of type `INT`, providing a `short` value as a placeholder in a
prepared statement throws a `ClassCastException`.
Test case:
```
final String sql =
"select \"empid\" from \"hr\".\"emps\" where \"empid\" in (?, ?)";
CalciteAssert.hr()
.query(sql)
.consumesPreparedStatement(p -> {
p.setShort(1, (short) 100);
p.setShort(2, (short) 110);
})
.returnsUnordered("empid=100", "empid=110");
```
Stack trace:
```
java.lang.ClassCastException: class java.lang.Short cannot be cast to class
java.lang.Integer (java.lang.Short and java.lang.Integer are in module
java.base of loader 'bootstrap')
at Baz$1$1.moveNext(Unknown Source)
at org.apache.calcite.linq4j.Linq4j$EnumeratorIterator.<init>(Linq4j.java:679)
```
---
.../adapter/enumerable/RexToLixTranslator.java | 20 ++++-
.../java/org/apache/calcite/test/JdbcTest.java | 93 ++++++++++++++++++++++
.../apache/calcite/linq4j/tree/Expressions.java | 56 ++++++++++++-
.../java/org/apache/calcite/linq4j/tree/Types.java | 25 ++++--
4 files changed, 185 insertions(+), 9 deletions(-)
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
index 5389fdccf7..4ea1d23701 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexToLixTranslator.java
@@ -31,6 +31,7 @@ import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.linq4j.tree.Statement;
+import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
@@ -1374,11 +1375,26 @@ public class RexToLixTranslator implements RexVisitor<RexToLixTranslator.Result>
}
final Type storageType = currentStorageType != null
? currentStorageType : typeFactory.getJavaClass(dynamicParam.getType());
- final Expression valueExpression =
+
+ final boolean isNumeric = SqlTypeFamily.NUMERIC.contains(dynamicParam.getType());
+
+ // For numeric types, use java.lang.Number to prevent cast exception
+ // when the parameter type differs from the target type
+ Expression argumentExpression =
EnumUtils.convert(
Expressions.call(root, BuiltInMethod.DATA_CONTEXT_GET.method,
Expressions.constant("?" + dynamicParam.getIndex())),
- storageType);
+ isNumeric ? java.lang.Number.class : storageType);
+
+ // Short-circuit if the expression evaluates to null. The cast
+ // may throw a NullPointerException as it calls methods on the
+ // object such as longValue().
+ Expression valueExpression =
+ Expressions.condition(
+ Expressions.equal(argumentExpression, Expressions.constant(null)),
+ Expressions.constant(null),
+ Types.castIfNecessary(storageType, argumentExpression));
+
final ParameterExpression valueVariable =
Expressions.parameter(valueExpression.getType(),
list.newName("value_dynamic_param"));
diff --git a/core/src/test/java/org/apache/calcite/test/JdbcTest.java b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
index 087b3854d1..8fe3de21ac 100644
--- a/core/src/test/java/org/apache/calcite/test/JdbcTest.java
+++ b/core/src/test/java/org/apache/calcite/test/JdbcTest.java
@@ -84,6 +84,7 @@ import org.apache.calcite.sql.SqlSpecialOperator;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.impl.SqlParserImpl;
+import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql2rel.SqlToRelConverter.Config;
import org.apache.calcite.test.schemata.catchall.CatchallSchema;
import org.apache.calcite.test.schemata.foodmart.FoodmartSchema;
@@ -8423,6 +8424,98 @@ public class JdbcTest {
});
}
+ @Test void bindByteParameter() {
+ for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
+ final String sql =
+ "with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ + "select * from cte where empid = ?";
+ CalciteAssert.hr()
+ .query(sql)
+ .consumesPreparedStatement(p -> {
+ p.setByte(1, (byte) 100);
+ })
+ .returnsUnordered("EMPID=100");
+ }
+ }
+
+ @Test void bindShortParameter() {
+ for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
+ final String sql =
+ "with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ + "select * from cte where empid = ?";
+
+ CalciteAssert.hr()
+ .query(sql)
+ .consumesPreparedStatement(p -> {
+ p.setShort(1, (short) 100);
+ })
+ .returnsUnordered("EMPID=100");
+ }
+ }
+
+ @Test void bindOverflowingTinyIntParameter() {
+ final String sql =
+ "with cte as (select cast(300 as smallint) as empid)"
+ + "select * from cte where empid = cast(? as tinyint)";
+
+ java.sql.SQLException t =
+ assertThrows(
+ java.sql.SQLException.class,
+ () -> CalciteAssert.hr()
+ .query(sql)
+ .consumesPreparedStatement(p -> {
+ p.setShort(1, (short) 300);
+ })
+ .returns(""));
+
+ assertThat(
+ "message matches",
+ t.getMessage().contains("value is outside the range of java.lang.Byte"));
+ }
+
+ @Test void bindIntParameter() {
+ for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
+ final String sql =
+ "with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ + "select * from cte where empid = ?";
+
+ CalciteAssert.hr()
+ .query(sql)
+ .consumesPreparedStatement(p -> {
+ p.setInt(1, 100);
+ })
+ .returnsUnordered("EMPID=100");
+ }
+ }
+
+ @Test void bindLongParameter() {
+ for (SqlTypeName tpe : SqlTypeName.INT_TYPES) {
+ final String sql =
+ "with cte as (select cast(100 as " + tpe.getName() + ") as empid)"
+ + "select * from cte where empid = ?";
+
+ CalciteAssert.hr()
+ .query(sql)
+ .consumesPreparedStatement(p -> {
+ p.setLong(1, 100);
+ })
+ .returnsUnordered("EMPID=100");
+ }
+ }
+
+ @Test void bindNumericParameter() {
+ final String sql =
+ "with cte as (select cast(100 as numeric(5)) as empid)"
+ + "select * from cte where empid = ?";
+
+ CalciteAssert.hr()
+ .query(sql)
+ .consumesPreparedStatement(p -> {
+ p.setLong(1, 100);
+ })
+ .returnsUnordered("EMPID=100");
+ }
+
private static String sums(int n, boolean c) {
final StringBuilder b = new StringBuilder();
for (int i = 0; i < n; i++) {
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java
index 2e49cbcb05..a7f7545b8e 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java
@@ -622,9 +622,49 @@ public abstract class Expressions {
* operation that throws an exception if the target type is
* overflowed.
*/
- public static UnaryExpression convertChecked(Expression expression,
+ public static Expression convertChecked(Expression expression,
Type type) {
- throw Extensions.todo();
+ if (type == Byte.class
+ || type == Short.class
+ || type == Integer.class
+ || type == Long.class) {
+ Class<?> typeClass = (Class<?>) type;
+
+ Object minValue;
+ Object maxValue;
+
+ try {
+ minValue = typeClass.getField("MIN_VALUE").get(null);
+ maxValue = typeClass.getField("MAX_VALUE").get(null);
+ } catch (IllegalAccessException | NoSuchFieldException e) {
+ throw new RuntimeException(e);
+ }
+
+ ThrowStatement throwStmt =
+ Expressions.throw_(
+ Expressions.new_(
+ IllegalArgumentException.class,
+ Expressions.constant("value is outside the range of " + typeClass.getName())));
+
+ // Covers all lower precision types
+ Expression longValue = Expressions.call(expression, "longValue");
+
+ Expression minCheck = Expressions.lessThan(longValue, Expressions.constant(minValue));
+ Expression maxCheck = Expressions.greaterThan(longValue, Expressions.constant(maxValue));
+
+ Primitive primitive = requireNonNull(Primitive.ofBox(type));
+ String primitiveName = requireNonNull(primitive.primitiveName);
+ Expression convertExpr = Expressions.call(expression, primitiveName + "Value");
+
+ return Expressions.convert_(
+ Expressions.makeTernary(
+ ExpressionType.Conditional,
+ Expressions.or(minCheck, maxCheck),
+ Expressions.fromStatement(throwStmt),
+ convertExpr), type);
+ }
+
+ throw new IllegalArgumentException("Type " + type.getTypeName() + " is not supported yet");
}
/**
@@ -2822,6 +2862,18 @@ public abstract class Expressions {
throw Extensions.todo();
}
+ /**
+ * Create an expression from a statement.
+ */
+ public static Expression fromStatement(Statement statement) {
+ FunctionExpression<Function<?>> lambda =
+ Expressions.lambda(
+ Blocks.toFunctionBlock(statement),
+ Collections.emptyList());
+
+ return Expressions.call(lambda, "apply");
+ }
+
/**
* Creates a statement that represents the throwing of an exception.
*/
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
index 0be03c9fc0..3dec960cf7 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java
@@ -28,6 +28,7 @@ import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
+import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -428,11 +429,25 @@ public abstract class Types {
&& Number.class.isAssignableFrom((Class) returnType)
&& type instanceof Class
&& Number.class.isAssignableFrom((Class) type)) {
- // E.g.
- // Integer foo(BigDecimal o) {
- // return o.intValue();
- // }
- return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));
+
+ if (returnType == BigDecimal.class) {
+ return Expressions.call(
+ BigDecimal.class,
+ "valueOf",
+ Expressions.call(expression, "longValue"));
+ } else if (
+ returnType == Byte.class
+ || returnType == Short.class
+ || returnType == Integer.class
+ || returnType == Long.class) {
+ return Expressions.convertChecked(expression, returnType);
+ } else {
+ // E.g.
+ // Integer foo(BigDecimal o) {
+ // return o.intValue();
+ // }
+ return Expressions.unbox(expression, requireNonNull(Primitive.ofBox(returnType)));
+ }
}
if (Primitive.is(returnType) && !Primitive.is(type)) {
// E.g.