This is an automated email from the ASF dual-hosted git repository.

taoran pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new f0dc2b0aea [CALCITE-5976] Function 
ARRAY_PREPEND/ARRAY_APPEND/ARRAY_INSERT gives an exception when the inserted 
element type does not equal the array component type
f0dc2b0aea is described below

commit f0dc2b0aea46b1fd3f37e0cc126edaf82ade2344
Author: caicancai <[email protected]>
AuthorDate: Mon Feb 26 23:14:21 2024 +0800

    [CALCITE-5976] Function ARRAY_PREPEND/ARRAY_APPEND/ARRAY_INSERT gives an 
exception when the inserted element type does not equal the array component type
    
    Co-authored-by: Ran Tao <[email protected]>
---
 .../calcite/sql/fun/SqlLibraryOperators.java       | 42 ++++++++++-
 .../calcite/sql/validate/SqlValidatorUtil.java     | 70 +++++++++++++++++
 .../org/apache/calcite/test/SqlOperatorTest.java   | 88 +++++++++++++++++++++-
 3 files changed, 196 insertions(+), 4 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java 
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 38b5b48f77..0c700d9de3 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1199,13 +1199,31 @@ public abstract class SqlLibraryOperators {
       return arrayType;
     }
     final RelDataType elementType = opBinding.collectOperandTypes().get(1);
+    requireNonNull(componentType, () -> "componentType of " + arrayType);
+
     RelDataType type =
         opBinding.getTypeFactory().leastRestrictive(
             ImmutableList.of(componentType, elementType));
+    requireNonNull(type, "inferred array element type");
+
     if (elementType.isNullable()) {
       type = opBinding.getTypeFactory().createTypeWithNullability(type, true);
     }
-    requireNonNull(type, "inferred array element type");
+
+    // Make an explicit CAST of the array elements and the inserted element to
+    // the least restrictive (widest) type when the array component type does
+    // not equal the inserted element type.
+    if (!componentType.equalsSansFieldNames(elementType)) {
+      // 0, 1 is the operand index to be CAST
+      // For array_append/array_prepend, 0 is the array arg and 1 is the 
inserted element
+      if (componentType.equalsSansFieldNames(type)) {
+        SqlValidatorUtil.
+            adjustTypeForArrayFunctions(type, opBinding, 1);
+      } else {
+        SqlValidatorUtil.
+            adjustTypeForArrayFunctions(type, opBinding, 0);
+      }
+    }
+
     return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, 
arrayType.isNullable());
   }
 
@@ -1282,14 +1300,32 @@ public abstract class SqlLibraryOperators {
     final RelDataType arrayType = opBinding.collectOperandTypes().get(0);
     final RelDataType componentType = arrayType.getComponentType();
     final RelDataType elementType = opBinding.collectOperandTypes().get(2);
+    requireNonNull(componentType, () -> "componentType of " + arrayType);
+
     // Compute the least restrictive type of componentType and elementType so
     // that an inserted element whose type differs from the array component
     // type can still be handled via an explicit cast.
-    RelDataType type = componentType;
+    RelDataType type =
+        opBinding.getTypeFactory().leastRestrictive(
+            ImmutableList.of(componentType, elementType));
+    requireNonNull(type, "inferred array element type");
+
     if (elementType.isNullable()) {
       type = opBinding.getTypeFactory().createTypeWithNullability(type, true);
     }
-    requireNonNull(type, "inferred array element type");
+    // Make an explicit CAST of the array elements and the inserted element to
+    // the least restrictive (widest) type when the array component type does
+    // not equal the inserted element type.
+    if (!componentType.equalsSansFieldNames(elementType)) {
+      // 0, 2 is the operand index to be CAST
+      // For array_insert, 0 is the array arg and 2 is the inserted element
+      if (componentType.equalsSansFieldNames(type)) {
+        SqlValidatorUtil.
+            adjustTypeForArrayFunctions(type, opBinding, 2);
+      } else {
+        SqlValidatorUtil.
+            adjustTypeForArrayFunctions(type, opBinding, 0);
+      }
+    }
     return SqlTypeUtil.createArrayType(opBinding.getTypeFactory(), type, 
arrayType.isNullable());
   }
 
diff --git 
a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java 
b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
index fd18c2d7b1..9aeec2da20 100644
--- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
+++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorUtil.java
@@ -37,6 +37,7 @@ import org.apache.calcite.schema.ExtensibleTable;
 import org.apache.calcite.schema.Table;
 import org.apache.calcite.schema.impl.AbstractSchema;
 import org.apache.calcite.schema.impl.AbstractTable;
+import org.apache.calcite.sql.SqlBasicCall;
 import org.apache.calcite.sql.SqlCall;
 import org.apache.calcite.sql.SqlCallBinding;
 import org.apache.calcite.sql.SqlDataTypeSpec;
@@ -1326,6 +1327,51 @@ public class SqlValidatorUtil {
     }
   }
 
+  /**
+   * Adjusts the types of specified operands in an array operation to match a 
given target type.
+   * This is particularly useful in the context of SQL operations involving 
array functions,
+   * where it's necessary to ensure that all operands have consistent types 
for the operation
+   * to be valid.
+   *
+   * <p>This method operates on the assumption that the operands to be 
adjusted are part of a
+   * {@link SqlCall}, which is bound within a {@link SqlOperatorBinding}. The 
operands to be
+   * cast are identified by their indexes within the {@code operands} list of 
the {@link SqlCall}.
+   * The method performs a dynamic check to determine if an operand is a basic 
call to an array.
+   * If so, it casts each element within the array to the target type.
+   * Otherwise, it casts the operand itself to the target type.
+   *
+   * <p>Example usage: For an operation like {@code array_append(array(1,2),
+   * cast(2 as double))}, the target type is DOUBLE, and this method ensures
+   * that the elements of the first (array) operand are cast to DOUBLE.
+   *
+   * @param targetType The target {@link RelDataType} to which the operands 
should be cast.
+   * @param opBinding  The {@link SqlOperatorBinding} context, which provides 
access to the
+   *                   {@link SqlCall} and its operands.
+   * @param indexes    The indexes of the operands within the {@link SqlCall} 
that need to be
+   *                   adjusted to the target type.
+   * @throws NullPointerException if {@code targetType} is {@code null}.
+   */
+  public static void adjustTypeForArrayFunctions(
+      RelDataType targetType, SqlOperatorBinding opBinding, int... indexes) {
+    if (opBinding instanceof SqlCallBinding) {
+      requireNonNull(targetType, "array function target type");
+      SqlCall call = ((SqlCallBinding) opBinding).getCall();
+      List<SqlNode> operands = call.getOperandList();
+      for (int idx : indexes) {
+        SqlNode operand = operands.get(idx);
+        if (operand instanceof SqlBasicCall
+            // Do not compare using SqlKind, because for other array function
+            // forms (such as Spark's array) the SqlKind is OTHER_FUNCTION;
+            // however, the operator name is the same across these forms.
+            && "ARRAY".equals(((SqlBasicCall) 
operand).getOperator().getName())) {
+          call.setOperand(idx, castArrayElementTo(operand, targetType));
+        } else {
+          call.setOperand(idx, castTo(operand, targetType));
+        }
+      }
+    }
+  }
+
   /**
    * When the map key or value does not equal the map component key type or 
value type,
    * make explicit casting.
@@ -1398,6 +1444,30 @@ public class SqlValidatorUtil {
         SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable()));
   }
 
+  /**
+   * Creates a CAST operation that casts each element of the given {@link 
SqlNode} to the
+   * specified type. The {@link SqlNode} representing an array and a {@link 
RelDataType}
+   * representing the target type. This method uses the {@link 
SqlStdOperatorTable#CAST}
+   * operator to create a new {@link SqlCall} node representing a CAST 
operation.
+   * Each element of original 'node' is cast to the desired 'type', preserving 
the
+   * nullability of the 'type'.
+   *
+   * @param node the {@link SqlNode} representing an array
+   * @param type the target {@link RelDataType} to cast each element to
+   * @return a new {@link SqlNode} representing the CAST operation
+   */
+  private static SqlNode castArrayElementTo(SqlNode node, RelDataType type) {
+    int i = 0;
+    for (SqlNode operand : ((SqlBasicCall) node).getOperandList()) {
+      SqlNode castedOperand =
+          SqlStdOperatorTable.CAST.createCall(SqlParserPos.ZERO,
+              operand,
+              
SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable()));
+      ((SqlBasicCall) node).setOperand(i++, castedOperand);
+    }
+    return node;
+  }
+
   //~ Inner Classes ----------------------------------------------------------
 
   /**
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java 
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index ebf23f290a..01f308adfe 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -6412,6 +6412,42 @@ public class SqlOperatorTest {
     f.checkType("array_append(cast(null as integer array), 1)", "INTEGER NOT 
NULL ARRAY");
     f.checkFails("^array_append(array[1, 2], true)^",
         "INTEGER is not comparable to BOOLEAN", false);
+
+    // element cast to the biggest type
+    f.checkScalar("array_append(array(cast(1 as tinyint)), 2)", "[1, 2]",
+        "INTEGER NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(cast(1 as double)), cast(2 as float))", 
"[1.0, 2.0]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1), cast(2 as float))", "[1.0, 2.0]",
+        "FLOAT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1), cast(2 as double))", "[1.0, 2.0]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1), cast(2 as bigint))", "[1, 2]",
+        "BIGINT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1, 2), cast(3 as double))", "[1.0, 2.0, 
3.0]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1, 2), cast(3 as float))", "[1.0, 2.0, 
3.0]",
+        "FLOAT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1, 2), cast(3 as bigint))", "[1, 2, 3]",
+        "BIGINT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1, 2), cast(null as double))", "[1.0, 
2.0, null]",
+        "DOUBLE ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1, 2), cast(null as float))", "[1.0, 
2.0, null]",
+        "FLOAT ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1), cast(null as bigint))", "[1, null]",
+        "BIGINT ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1), cast(100 as decimal))", "[1, 100]",
+        "DECIMAL(19, 0) NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(1), 10e6)", "[1.0, 1.0E7]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_append(array(), cast(null as double))", "[null]",
+        "DOUBLE ARRAY NOT NULL");
+    f.checkScalar("array_append(array(), cast(null as float))", "[null]",
+        "FLOAT ARRAY NOT NULL");
+    f.checkScalar("array_append(array(), cast(null as tinyint))", "[null]",
+        "TINYINT ARRAY NOT NULL");
+    f.checkScalar("array_append(array(), cast(null as bigint))", "[null]",
+        "BIGINT ARRAY NOT NULL");
   }
 
   /** Tests {@code ARRAY_COMPACT} function from Spark. */
@@ -6648,7 +6684,7 @@ public class SqlOperatorTest {
         "NULL ARRAY NOT NULL");
     f.checkScalar("array_prepend(array(), null)", "[null]",
         "UNKNOWN ARRAY NOT NULL");
-    f.checkScalar("array_append(array(), 1)", "[1]",
+    f.checkScalar("array_prepend(array(), 1)", "[1]",
         "INTEGER NOT NULL ARRAY NOT NULL");
     f.checkScalar("array_prepend(array[array[1, 2]], array[3, 4])", "[[3, 4], 
[1, 2]]",
         "INTEGER NOT NULL ARRAY NOT NULL ARRAY NOT NULL");
@@ -6658,6 +6694,40 @@ public class SqlOperatorTest {
     f.checkType("array_prepend(cast(null as integer array), 1)", "INTEGER NOT 
NULL ARRAY");
     f.checkFails("^array_prepend(array[1, 2], true)^",
         "INTEGER is not comparable to BOOLEAN", false);
+
+    // element cast to the biggest type
+    f.checkScalar("array_prepend(array(1), cast(3 as float))", "[3.0, 1.0]",
+        "FLOAT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1), cast(3 as bigint))", "[3, 1]",
+        "BIGINT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(2), cast(3 as double))", "[3.0, 2.0]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1, 2), cast(3 as float))", "[3.0, 1.0, 
2.0]",
+        "FLOAT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(2, 1), cast(3 as double))", "[3.0, 2.0, 
1.0]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1, 2), cast(3 as tinyint))", "[3, 1, 
2]",
+        "INTEGER NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1, 2), cast(3 as bigint))", "[3, 1, 2]",
+        "BIGINT NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1, 2), cast(null as double))", "[null, 
1.0, 2.0]",
+        "DOUBLE ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1, 2), cast(null as float))", "[null, 
1.0, 2.0]",
+        "FLOAT ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1), cast(null as bigint))", "[null, 1]",
+        "BIGINT ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1), cast(100 as decimal))", "[100, 1]",
+        "DECIMAL(19, 0) NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(1), 10e6)", "[1.0E7, 1.0]",
+        "DOUBLE NOT NULL ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(), cast(null as double))", "[null]",
+        "DOUBLE ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(), cast(null as float))", "[null]",
+        "FLOAT ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(), cast(null as tinyint))", "[null]",
+        "TINYINT ARRAY NOT NULL");
+    f.checkScalar("array_prepend(array(), cast(null as bigint))", "[null]",
+        "BIGINT ARRAY NOT NULL");
   }
 
   /** Tests {@code ARRAY_REMOVE} function from Spark. */
@@ -6944,6 +7014,22 @@ public class SqlOperatorTest {
         "(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL");
     f1.checkScalar("array_insert(array[map[1, 'a']], -1, map[2, 'b'])", 
"[{2=b}, {1=a}]",
         "(INTEGER NOT NULL, CHAR(1) NOT NULL) MAP NOT NULL ARRAY NOT NULL");
+
+    // element cast to the biggest type
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as tinyint))",
+        "[1, 2, 4, 3]", "INTEGER NOT NULL ARRAY NOT NULL");
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as double))",
+        "[1.0, 2.0, 4.0, 3.0]", "DOUBLE NOT NULL ARRAY NOT NULL");
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as float))",
+        "[1.0, 2.0, 4.0, 3.0]", "FLOAT NOT NULL ARRAY NOT NULL");
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(4 as bigint))",
+        "[1, 2, 4, 3]", "BIGINT NOT NULL ARRAY NOT NULL");
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as bigint))",
+        "[1, 2, null, 3]", "BIGINT ARRAY NOT NULL");
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as float))",
+        "[1.0, 2.0, null, 3.0]", "FLOAT ARRAY NOT NULL");
+    f1.checkScalar("array_insert(array(1, 2, 3), 3, cast(null as tinyint))",
+        "[1, 2, null, 3]", "INTEGER ARRAY NOT NULL");
   }
 
   /** Tests {@code ARRAY_INTERSECT} function from Spark. */

Reply via email to