This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 90beb01b07 [CALCITE-6283] Function ARRAY_APPEND with a NULL array argument crashes with NullPointerException
90beb01b07 is described below

commit 90beb01b0713a010167aeb2c810ffebccf3aa3e1
Author: Mihai Budiu <[email protected]>
AuthorDate: Mon Mar 11 09:50:04 2024 -0700

    [CALCITE-6283] Function ARRAY_APPEND with a NULL array argument crashes with NullPointerException
    
    Signed-off-by: Mihai Budiu <[email protected]>
---
 .../calcite/sql/fun/SqlLibraryOperators.java       | 16 ++++---
 .../sql/type/ArrayElementOperandTypeChecker.java   | 54 +++++++++++-----------
 .../org/apache/calcite/sql/type/OperandTypes.java  |  9 +++-
 .../org/apache/calcite/test/SqlOperatorTest.java   | 21 ++++++++-
 4 files changed, 65 insertions(+), 35 deletions(-)

diff --git 
a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java 
b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
index 130e01fe18..003e66fea1 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlLibraryOperators.java
@@ -1180,6 +1180,10 @@ public abstract class SqlLibraryOperators {
   private static RelDataType arrayAppendPrependReturnType(SqlOperatorBinding 
opBinding) {
     final RelDataType arrayType = opBinding.collectOperandTypes().get(0);
     final RelDataType componentType = arrayType.getComponentType();
+    if (componentType == null) {
+      // NULL used for array.
+      return arrayType;
+    }
     final RelDataType elementType = opBinding.collectOperandTypes().get(1);
     RelDataType type =
         opBinding.getTypeFactory().leastRestrictive(
@@ -1196,7 +1200,7 @@ public abstract class SqlLibraryOperators {
   public static final SqlFunction ARRAY_APPEND =
       SqlBasicFunction.create(SqlKind.ARRAY_APPEND,
           SqlLibraryOperators::arrayAppendPrependReturnType,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "EXISTS(array, lambda)" function (Spark); returns whether a 
predicate holds
    * for one or more elements in the array. */
@@ -1311,35 +1315,35 @@ public abstract class SqlLibraryOperators {
   public static final SqlFunction ARRAY_MAX =
       SqlBasicFunction.create(SqlKind.ARRAY_MAX,
           ReturnTypes.TO_COLLECTION_ELEMENT_FORCE_NULLABLE,
-          OperandTypes.ARRAY);
+          OperandTypes.ARRAY_NONNULL);
 
   /** The "ARRAY_MAX(array)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_MIN =
       SqlBasicFunction.create(SqlKind.ARRAY_MIN,
           ReturnTypes.TO_COLLECTION_ELEMENT_FORCE_NULLABLE,
-          OperandTypes.ARRAY);
+          OperandTypes.ARRAY_NONNULL);
 
   /** The "ARRAY_POSITION(array, element)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_POSITION =
       SqlBasicFunction.create(SqlKind.ARRAY_POSITION,
           ReturnTypes.BIGINT_NULLABLE,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "ARRAY_PREPEND(array, element)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_PREPEND =
       SqlBasicFunction.create(SqlKind.ARRAY_PREPEND,
           SqlLibraryOperators::arrayAppendPrependReturnType,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "ARRAY_REMOVE(array, element)" function. */
   @LibraryOperator(libraries = {SPARK})
   public static final SqlFunction ARRAY_REMOVE =
       SqlBasicFunction.create(SqlKind.ARRAY_REMOVE,
           ReturnTypes.ARG0_NULLABLE,
-          OperandTypes.ARRAY_ELEMENT);
+          OperandTypes.ARRAY_ELEMENT_NONNULL);
 
   /** The "ARRAY_REPEAT(element, count)" function. */
   @LibraryOperator(libraries = {SPARK})
diff --git 
a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
 
b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
index bed38a73e6..13eb7f3414 100644
--- 
a/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
+++ 
b/core/src/main/java/org/apache/calcite/sql/type/ArrayElementOperandTypeChecker.java
@@ -21,7 +21,6 @@ import org.apache.calcite.sql.SqlCallBinding;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlOperandCountRange;
 import org.apache.calcite.sql.SqlOperator;
-import org.apache.calcite.sql.SqlUtil;
 
 import com.google.common.collect.ImmutableList;
 
@@ -34,19 +33,14 @@ import static org.apache.calcite.util.Static.RESOURCE;
 public class ArrayElementOperandTypeChecker implements SqlOperandTypeChecker {
   //~ Instance fields --------------------------------------------------------
 
-  private final boolean allowNullCheck;
-  private final boolean allowCast;
+  private final boolean arrayMayBeNull;
+  private final boolean elementMayBeNull;
 
   //~ Constructors -----------------------------------------------------------
 
-  public ArrayElementOperandTypeChecker() {
-    this.allowNullCheck = false;
-    this.allowCast = false;
-  }
-
-  public ArrayElementOperandTypeChecker(boolean allowNullCheck, boolean 
allowCast) {
-    this.allowNullCheck = allowNullCheck;
-    this.allowCast = allowCast;
+  public ArrayElementOperandTypeChecker(boolean arrayMayBeNull, boolean 
elementMayBeNull) {
+    this.arrayMayBeNull = arrayMayBeNull;
+    this.elementMayBeNull = elementMayBeNull;
   }
 
   //~ Methods ----------------------------------------------------------------
@@ -54,20 +48,19 @@ public class ArrayElementOperandTypeChecker implements 
SqlOperandTypeChecker {
   @Override public boolean checkOperandTypes(
       SqlCallBinding callBinding,
       boolean throwOnFailure) {
-    if (allowNullCheck) {
-      // no operand can be null for type-checking to succeed
-      for (SqlNode node : callBinding.operands()) {
-        if (SqlUtil.isNullLiteral(node, allowCast)) {
-          if (throwOnFailure) {
-            throw callBinding.getValidator().newValidationError(node, 
RESOURCE.nullIllegal());
-          } else {
-            return false;
-          }
-        }
+    final SqlNode op0 = callBinding.operand(0);
+    RelDataType arrayType = SqlTypeUtil.deriveType(callBinding, op0);
+
+    // Check if op0 is allowed to be NULL
+    if (!this.arrayMayBeNull && arrayType.getSqlTypeName() == 
SqlTypeName.NULL) {
+      if (throwOnFailure) {
+        throw callBinding.getValidator().newValidationError(op0, 
RESOURCE.nullIllegal());
+      } else {
+        return false;
       }
     }
 
-    final SqlNode op0 = callBinding.operand(0);
+    // Check that op0 is an ARRAY type
     if (!OperandTypes.ARRAY.checkSingleOperandType(
         callBinding,
         op0,
@@ -75,20 +68,29 @@ public class ArrayElementOperandTypeChecker implements 
SqlOperandTypeChecker {
         throwOnFailure)) {
       return false;
     }
-
     RelDataType arrayComponentType =
         getComponentTypeOrThrow(SqlTypeUtil.deriveType(callBinding, op0));
+
     final SqlNode op1 = callBinding.operand(1);
-    RelDataType aryType1 = SqlTypeUtil.deriveType(callBinding, op1);
+    RelDataType elementType = SqlTypeUtil.deriveType(callBinding, op1);
+
+    // Check if elementType is allowed to be NULL
+    if (!this.elementMayBeNull && elementType.getSqlTypeName() == 
SqlTypeName.NULL) {
+      if (throwOnFailure) {
+        throw callBinding.getValidator().newValidationError(op1, 
RESOURCE.nullIllegal());
+      } else {
+        return false;
+      }
+    }
 
     RelDataType biggest =
         callBinding.getTypeFactory().leastRestrictive(
-            ImmutableList.of(arrayComponentType, aryType1));
+            ImmutableList.of(arrayComponentType, elementType));
     if (biggest == null) {
       if (throwOnFailure) {
         throw callBinding.newError(
             RESOURCE.typeNotComparable(
-                arrayComponentType.toString(), aryType1.toString()));
+                arrayComponentType.toString(), elementType.toString()));
       }
 
       return false;
diff --git a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java 
b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
index 18767f5a7d..67367ba3f4 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/OperandTypes.java
@@ -612,10 +612,15 @@ public abstract class OperandTypes {
       new ArrayFunctionOperandTypeChecker();
 
   public static final SqlOperandTypeChecker ARRAY_ELEMENT =
-      new ArrayElementOperandTypeChecker();
+      new ArrayElementOperandTypeChecker(true, true);
 
   public static final SqlOperandTypeChecker ARRAY_ELEMENT_NONNULL =
-      new ArrayElementOperandTypeChecker(true, false);
+      new ArrayElementOperandTypeChecker(false, true);
+
+  /** Type checker that accepts an ARRAY as the first argument, but not
+   * an expression with type NULL (i.e. a NULL literal). */
+  public static final SqlOperandTypeChecker ARRAY_NONNULL =
+      family(SqlTypeFamily.ARRAY).and(new NotNullOperandTypeChecker(1, false));
 
   public static final SqlOperandTypeChecker ARRAY_INSERT =
       new ArrayInsertOperandTypeChecker();
diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java 
b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
index 0a0e37e959..8aaef82fcf 100644
--- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
+++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java
@@ -6314,6 +6314,25 @@ public class SqlOperatorTest {
     f.checkScalar("rand_integer(2, 11)", 1, "INTEGER NOT NULL");
   }
 
+  /** Test case for <a 
href="https://issues.apache.org/jira/browse/CALCITE-6283";>
+   * [CALCITE-6283] Function array_append with a NULL array argument crashes 
with
+   * NullPointerException</a>. */
+  @Test void testArrayNullFunc() {
+    final String expected = "Illegal use of 'NULL'";
+    final SqlOperatorFixture f = fixture().withLibrary(SqlLibrary.SPARK);
+    f.checkFails("array_append(^null^, 2)", expected, false);
+    f.checkFails("array_prepend(^null^, 2)", expected, false);
+    f.checkFails("array_remove(^null^, 2)", expected, false);
+    f.checkFails("array_contains(^null^, 2)", expected, false);
+    f.checkFails("array_position(^null^, 2)", expected, false);
+    f.checkFails("^array_min(null)^",
+        "Cannot apply 'ARRAY_MIN' to arguments of type 
'ARRAY_MIN\\(<NULL>\\)'."
+            + " Supported form\\(s\\): 'ARRAY_MIN\\(<ARRAY>\\)'", false);
+    f.checkFails("^array_max(null)^",
+        "Cannot apply 'ARRAY_MAX' to arguments of type 
'ARRAY_MAX\\(<NULL>\\)'."
+        + " Supported form\\(s\\): 'ARRAY_MAX\\(<ARRAY>\\)'", false);
+  }
+
   /** Tests {@code ARRAY_APPEND} function from Spark. */
   @Test void testArrayAppendFunc() {
     final SqlOperatorFixture f0 = fixture();
@@ -6421,7 +6440,7 @@ public class SqlOperatorTest {
         "INTEGER is not comparable to BOOLEAN", false);
 
     // check null without cast
-    f.checkFails("array_contains(array[1, 2], ^null^)", "Illegal use of 
'NULL'", false);
+    f.checkNull("array_contains(array[1, 2], null)");
     f.checkFails("array_contains(^null^, array[1, 2])", "Illegal use of 
'NULL'", false);
     f.checkFails("array_contains(^null^, null)", "Illegal use of 'NULL'", 
false);
   }

Reply via email to