This is an automated email from the ASF dual-hosted git repository.

taoran pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git

commit e066266dcde21b3a0de90ec601ff539be1b8d7a3
Author: Ran Tao <chucheng...@gmail.com>
AuthorDate: Tue Dec 12 15:44:00 2023 +0800

    [CALCITE-6127] The spark array function gives NullPointerException when element is row type
---
 .../calcite/sql/fun/SqlLibraryOperators.java       | 18 ++++---
 .../org/apache/calcite/sql/type/OperandTypes.java  | 56 ++++++++++++++++++++++
 .../org/apache/calcite/test/SqlOperatorTest.java   | 28 +++++++++++
 3 files changed, 96 insertions(+), 6 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index a5440a211d..427e305d50 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1061,13 +1061,22 @@ public abstract class SqlLibraryOperators {
   private static RelDataType arrayReturnType(SqlOperatorBinding opBinding) {
     final List<RelDataType> operandTypes = opBinding.collectOperandTypes();
 
-    // only numeric & character types check
+    // only numeric & character types check, this is a special spark array case
+    // the form like ARRAY(1, 2, '3') will return ["1", "2", "3"]
     boolean hasNumeric = false;
     boolean hasCharacter = false;
     boolean hasOthers = false;
     for (RelDataType type : operandTypes) {
       SqlTypeFamily family = type.getSqlTypeName().getFamily();
-      requireNonNull(family, "array element type family");
+      // some types such as Row, the family is null, fallback to normal inferred type logic
+      if (family == null) {
+        hasOthers = true;
+        break;
+      }
+      // skip it because we allow NULL literal
+      if (SqlTypeUtil.isNull(type)) {
+        continue;
+      }
       switch (family) {
       case NUMERIC:
         hasNumeric = true;
@@ -1075,9 +1084,6 @@ public abstract class SqlLibraryOperators {
       case CHARACTER:
         hasCharacter = true;
         break;
-      case NULL:
-        // skip it becase we allow null
-        break;
       default:
         hasOthers = true;
         break;
@@ -1113,7 +1119,7 @@ public abstract class SqlLibraryOperators {
   public static final SqlFunction ARRAY =
       SqlBasicFunction.create("ARRAY",
           SqlLibraryOperators::arrayReturnType,
-          OperandTypes.SAME_VARIADIC,
+          OperandTypes.ARRAY_FUNCTION,
           SqlFunctionCategory.SYSTEM);
 
   private static RelDataType mapReturnType(SqlOperatorBinding opBinding) {
diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
index 295c95e0f3..8abb7c8178 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlLiteral;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlOperandCountRange;
 import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.SqlOperatorBinding;
 import org.apache.calcite.sql.SqlUtil;
 import org.apache.calcite.sql.validate.SqlValidatorScope;
 import org.apache.calcite.util.ImmutableIntList;
@@ -560,6 +561,9 @@ public abstract class OperandTypes {
   public static final SqlSingleOperandTypeChecker MAP =
       family(SqlTypeFamily.MAP);
 
+  public static final SqlOperandTypeChecker ARRAY_FUNCTION =
+      new ArrayFunctionOperandTypeChecker();
+
   public static final SqlOperandTypeChecker ARRAY_ELEMENT =
       new ArrayElementOperandTypeChecker();
 
@@ -1225,6 +1229,58 @@ public abstract class OperandTypes {
     }
   }
 
+  /**
+   * Operand type-checking strategy for a ARRAY function, it allows empty array.
+   */
+  private static class ArrayFunctionOperandTypeChecker
+      extends SameOperandTypeChecker {
+
+    ArrayFunctionOperandTypeChecker() {
+      // The args of array are non-fixed, so we set to -1 here. then operandCount
+      // can dynamically set according to the number of input args.
+      // details please see SameOperandTypeChecker#getOperandList.
+      super(-1);
+    }
+
+    @Override protected boolean checkOperandTypesImpl(
+        SqlOperatorBinding operatorBinding,
+        boolean throwOnFailure,
+        @Nullable SqlCallBinding callBinding) {
+      if (throwOnFailure && callBinding == null) {
+        throw new IllegalArgumentException(
+            "callBinding must be non-null in case throwOnFailure=true");
+      }
+      int nOperandsActual = nOperands;
+      if (nOperandsActual == -1) {
+        nOperandsActual = operatorBinding.getOperandCount();
+      }
+      RelDataType[] types = new RelDataType[nOperandsActual];
+      final List<Integer> operandList =
+          getOperandList(operatorBinding.getOperandCount());
+      for (int i : operandList) {
+        types[i] = operatorBinding.getOperandType(i);
+      }
+      for (int i : operandList) {
+        if (i > 0) {
+          // we replace SqlTypeUtil.isComparable with SqlTypeUtil.leastRestrictiveForComparison
+          // to handle struct type and NULL constant.
+          // details please see: https://issues.apache.org/jira/browse/CALCITE-6163
+          RelDataType type =
+              SqlTypeUtil.leastRestrictiveForComparison(operatorBinding.getTypeFactory(),
+                  types[i], types[i - 1]);
+          if (type == null) {
+            if (!throwOnFailure) {
+              return false;
+            }
+            throw requireNonNull(callBinding, "callBinding").newValidationError(
+                RESOURCE.needSameTypeParameter());
+          }
+        }
+      }
+      return true;
+    }
+  }
+
   /**
    * Operand type-checking strategy for a MAP function, it allows empty map.
    */
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 9578522bd8..2f21b83dbd 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -10537,6 +10537,34 @@ public class SqlOperatorTest {
         "[null, foo]", "CHAR(3) ARRAY NOT NULL");
     f2.checkScalar("array(null)",
         "[null]", "NULL ARRAY NOT NULL");
+    // check complex type
+    f2.checkScalar("array(row(1))", "[{1}]",
+        "RecordType(INTEGER NOT NULL EXPR$0) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(1, null))", "[{1, null}]",
+        "RecordType(INTEGER NOT NULL EXPR$0, NULL EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(null, 1))", "[{null, 1}]",
+        "RecordType(NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(1, 2))", "[{1, 2}]",
+        "RecordType(INTEGER NOT NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(1, 2), null)",
+        "[{1, 2}, null]", "RecordType(INTEGER EXPR$0, INTEGER EXPR$1) ARRAY NOT NULL");
+    f2.checkScalar("array(null, row(1, 2))",
+        "[null, {1, 2}]", "RecordType(INTEGER EXPR$0, INTEGER EXPR$1) ARRAY NOT NULL");
+    f2.checkScalar("array(row(1, null), row(2, null))", "[{1, null}, {2, null}]",
+        "RecordType(INTEGER NOT NULL EXPR$0, NULL EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(null, 1), row(null, 2))", "[{null, 1}, {null, 2}]",
+        "RecordType(NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(1, null), row(null, 2))", "[{1, null}, {null, 2}]",
+        "RecordType(INTEGER EXPR$0, INTEGER EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(null, 1), row(2, null))", "[{null, 1}, {2, null}]",
+        "RecordType(INTEGER EXPR$0, INTEGER EXPR$1) NOT NULL ARRAY NOT NULL");
+    f2.checkScalar("array(row(1, 2), row(3, 4))", "[{1, 2}, {3, 4}]",
+        "RecordType(INTEGER NOT NULL EXPR$0, INTEGER NOT NULL EXPR$1) NOT NULL ARRAY NOT NULL");
+    // checkFails
+    f2.checkFails("^array(row(1), row(2, 3))^",
+        "Parameters must be of the same type", false);
+    f2.checkFails("^array(row(1), row(2, 3), null)^",
+        "Parameters must be of the same type", false);
     // calcite default cast char type will fill extra spaces
     f2.checkScalar("array(1, 2, 'Hi')",
         "[1 , 2 , Hi]", "CHAR(2) NOT NULL ARRAY NOT NULL");

Reply via email to