This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a36c43604a6 [SPARK-49506][SQL] Optimize ArrayBinarySearch for 
foldable array
6a36c43604a6 is described below

commit 6a36c43604a638f603a3a40c22ee1e6bd3ae8d7e
Author: panbingkun <[email protected]>
AuthorDate: Tue Oct 29 07:23:57 2024 -0700

    [SPARK-49506][SQL] Optimize ArrayBinarySearch for foldable array
    
    ### What changes were proposed in this pull request?
    This PR aims to:
    - optimize `ArrayBinarySearch` for a `foldable` array.
    - fix a bug in the original implementation (the null ordering in the `ArrayBinarySearch` comparator was reversed).
    
    ### Why are the changes needed?
    The changes improve performance of the `array_binary_search()` function.
    - Create an instance of `foldable{DataType}ArrayData` only once, at 
initialization (avoiding frequent calls to `ArrayData.to{DataType}Array()`), and 
reuse it inside `replacement` when the `array` parameter is foldable.
    
    Before:
    ```
    Running benchmark: array binary search
      Running case: no foldable optimize
      Stopped after 100 iterations, 93668 ms
    
    OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 14.6.1
    Apple M2
    array binary search:                      Best Time(ms)   Avg Time(ms)   
Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    
------------------------------------------------------------------------------------------------------------------------
    no foldable optimize                                916            937      
    24         10.9          91.6       1.0X
    ```
    
    After:
    ```
    Running benchmark: array binary search
      Running case: has foldable optimize
      Stopped after 100 iterations, 17206 ms
    
    OpenJDK 64-Bit Server VM 17.0.10+7-LTS on Mac OS X 14.6.1
    Apple M2
    array binary search:                      Best Time(ms)   Avg Time(ms)   
Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
    
------------------------------------------------------------------------------------------------------------------------
    has foldable optimize                               164            172      
    22         61.1          16.4       1.0X
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    - Update the existing UTs.
    - Pass GA.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48225 from panbingkun/SPARK-49506_FOLLOWUP.
    
    Authored-by: panbingkun <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/expressions/ArrayExpressionUtils.java | 152 +++++++++-------
 .../sql/catalyst/expressions/ToJavaArrayUtils.java | 112 ++++++++++++
 .../sql/catalyst/expressions/ToJavaArray.scala     | 105 +++++++++++
 .../expressions/collectionOperations.scala         |  50 ++----
 .../sql/catalyst/expressions/objects/objects.scala |  26 ++-
 .../expressions/CollectionExpressionsSuite.scala   | 192 ++++++++++++++++++++-
 6 files changed, 534 insertions(+), 103 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
index ff6525acbe53..5411aa684ea5 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
@@ -19,20 +19,13 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.util.Arrays;
 import java.util.Comparator;
 
-import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.SQLOrderingUtil;
-import org.apache.spark.sql.types.ByteType$;
-import org.apache.spark.sql.types.BooleanType$;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DoubleType$;
-import org.apache.spark.sql.types.FloatType$;
-import org.apache.spark.sql.types.IntegerType$;
-import org.apache.spark.sql.types.LongType$;
-import org.apache.spark.sql.types.ShortType$;
 
 public class ArrayExpressionUtils {
 
-  private static final Comparator<Object> booleanComp = (o1, o2) -> {
+  // comparator
+  // Boolean ascending nullable comparator
+  private static final Comparator<Boolean> booleanComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -40,11 +33,11 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    boolean c1 = (Boolean) o1, c2 = (Boolean) o2;
-    return c1 == c2 ? 0 : (c1 ? 1 : -1);
+    return o1.equals(o2) ? 0 : (o1 ? 1 : -1);
   };
 
-  private static final Comparator<Object> byteComp = (o1, o2) -> {
+  // Byte ascending nullable comparator
+  private static final Comparator<Byte> byteComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -52,11 +45,11 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    byte c1 = (Byte) o1, c2 = (Byte) o2;
-    return Byte.compare(c1, c2);
+    return Byte.compare(o1, o2);
   };
 
-  private static final Comparator<Object> shortComp = (o1, o2) -> {
+  // Short ascending nullable comparator
+  private static final Comparator<Short> shortComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -64,11 +57,11 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    short c1 = (Short) o1, c2 = (Short) o2;
-    return Short.compare(c1, c2);
+    return Short.compare(o1, o2);
   };
 
-  private static final Comparator<Object> integerComp = (o1, o2) -> {
+  // Integer ascending nullable comparator
+  private static final Comparator<Integer> integerComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -76,11 +69,11 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    int c1 = (Integer) o1, c2 = (Integer) o2;
-    return Integer.compare(c1, c2);
+    return Integer.compare(o1, o2);
   };
 
-  private static final Comparator<Object> longComp = (o1, o2) -> {
+  // Long ascending nullable comparator
+  private static final Comparator<Long> longComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -88,11 +81,11 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    long c1 = (Long) o1, c2 = (Long) o2;
-    return Long.compare(c1, c2);
+    return Long.compare(o1, o2);
   };
 
-  private static final Comparator<Object> floatComp = (o1, o2) -> {
+  // Float ascending nullable comparator
+  private static final Comparator<Float> floatComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -100,11 +93,11 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    float c1 = (Float) o1, c2 = (Float) o2;
-    return SQLOrderingUtil.compareFloats(c1, c2);
+    return SQLOrderingUtil.compareFloats(o1, o2);
   };
 
-  private static final Comparator<Object> doubleComp = (o1, o2) -> {
+  // Double ascending nullable comparator
+  private static final Comparator<Double> doubleComp = (o1, o2) -> {
     if (o1 == null && o2 == null) {
       return 0;
     } else if (o1 == null) {
@@ -112,65 +105,104 @@ public class ArrayExpressionUtils {
     } else if (o2 == null) {
       return 1;
     }
-    double c1 = (Double) o1, c2 = (Double) o2;
-    return SQLOrderingUtil.compareDoubles(c1, c2);
+    return SQLOrderingUtil.compareDoubles(o1, o2);
   };
 
-  public static int binarySearchNullSafe(ArrayData data, Boolean value) {
-    return Arrays.binarySearch(data.toObjectArray(BooleanType$.MODULE$), 
value, booleanComp);
+  // boolean
+  // boolean non-nullable
+  public static int binarySearch(boolean[] data, boolean value) {
+    int low = 0;
+    int high = data.length - 1;
+
+    while (low <= high) {
+      int mid = (low + high) >>> 1;
+      boolean midVal = data[mid];
+
+      if (value == midVal) {
+        return mid; // key found
+      } else if (value) {
+        low = mid + 1;
+      } else {
+        high = mid - 1;
+      }
+    }
+
+    return -(low + 1);  // key not found.
+  }
+
+  // Boolean nullable
+  public static int binarySearch(Boolean[] data, Boolean value) {
+    return Arrays.binarySearch(data, value, booleanComp);
   }
 
-  public static int binarySearch(ArrayData data, byte value) {
-    return Arrays.binarySearch(data.toByteArray(), value);
+  // byte
+  // byte non-nullable
+  public static int binarySearch(byte[] data, byte value) {
+    return Arrays.binarySearch(data, value);
   }
 
-  public static int binarySearchNullSafe(ArrayData data, Byte value) {
-    return Arrays.binarySearch(data.toObjectArray(ByteType$.MODULE$), value, 
byteComp);
+  // Byte nullable
+  public static int binarySearch(Byte[] data, Byte value) {
+    return Arrays.binarySearch(data, value, byteComp);
   }
 
-  public static int binarySearch(ArrayData data, short value) {
-    return Arrays.binarySearch(data.toShortArray(), value);
+  // short
+  // short non-nullable
+  public static int binarySearch(short[] data, short value) {
+    return Arrays.binarySearch(data, value);
   }
 
-  public static int binarySearchNullSafe(ArrayData data, Short value) {
-    return Arrays.binarySearch(data.toObjectArray(ShortType$.MODULE$), value, 
shortComp);
+  // Short nullable
+  public static int binarySearch(Short[] data, Short value) {
+    return Arrays.binarySearch(data, value, shortComp);
   }
 
-  public static int binarySearch(ArrayData data, int value) {
-    return Arrays.binarySearch(data.toIntArray(), value);
+  // int
+  // int non-nullable
+  public static int binarySearch(int[] data, int value) {
+    return Arrays.binarySearch(data, value);
   }
 
-  public static int binarySearchNullSafe(ArrayData data, Integer value) {
-    return Arrays.binarySearch(data.toObjectArray(IntegerType$.MODULE$), 
value, integerComp);
+  // Integer nullable
+  public static int binarySearch(Integer[] data, Integer value) {
+    return Arrays.binarySearch(data, value, integerComp);
   }
 
-  public static int binarySearch(ArrayData data, long value) {
-    return Arrays.binarySearch(data.toLongArray(), value);
+  // long
+  // long non-nullable
+  public static int binarySearch(long[] data, long value) {
+    return Arrays.binarySearch(data, value);
   }
 
-  public static int binarySearchNullSafe(ArrayData data, Long value) {
-    return Arrays.binarySearch(data.toObjectArray(LongType$.MODULE$), value, 
longComp);
+  // Long nullable
+  public static int binarySearch(Long[] data, Long value) {
+    return Arrays.binarySearch(data, value, longComp);
   }
 
-  public static int binarySearch(ArrayData data, float value) {
-    return Arrays.binarySearch(data.toFloatArray(), value);
+  // float
+  // float non-nullable
+  public static int binarySearch(float[] data, float value) {
+    return Arrays.binarySearch(data, value);
   }
 
-  public static int binarySearchNullSafe(ArrayData data, Float value) {
-    return Arrays.binarySearch(data.toObjectArray(FloatType$.MODULE$), value, 
floatComp);
+  // Float nullable
+  public static int binarySearch(Float[] data, Float value) {
+    return Arrays.binarySearch(data, value, floatComp);
   }
 
-  public static int binarySearch(ArrayData data, double value) {
-    return Arrays.binarySearch(data.toDoubleArray(), value);
+  // double
+  // double non-nullable
+  public static int binarySearch(double[] data, double value) {
+    return Arrays.binarySearch(data, value);
   }
 
-  public static int binarySearchNullSafe(ArrayData data, Double value) {
-    return Arrays.binarySearch(data.toObjectArray(DoubleType$.MODULE$), value, 
doubleComp);
+  // Double nullable
+  public static int binarySearch(Double[] data, Double value) {
+    return Arrays.binarySearch(data, value, doubleComp);
   }
 
-  public static int binarySearch(
-    DataType elementType, Comparator<Object> comp, ArrayData data, Object 
value) {
-    Object[] array = data.toObjectArray(elementType);
-    return Arrays.binarySearch(array, value, comp);
+  // Object
+  public static int binarySearch(Object[] data, Object value, 
Comparator<Object> comp) {
+    return Arrays.binarySearch(data, value, comp);
   }
 }
diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ToJavaArrayUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ToJavaArrayUtils.java
new file mode 100644
index 000000000000..ead138590ca5
--- /dev/null
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ToJavaArrayUtils.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import scala.reflect.ClassTag$;
+
+import org.apache.spark.sql.catalyst.util.ArrayData;
+
+import static org.apache.spark.sql.types.DataTypes.BooleanType;
+import static org.apache.spark.sql.types.DataTypes.ByteType;
+import static org.apache.spark.sql.types.DataTypes.DoubleType;
+import static org.apache.spark.sql.types.DataTypes.FloatType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+import static org.apache.spark.sql.types.DataTypes.ShortType;
+
+public class ToJavaArrayUtils {
+
+  // boolean
+  // boolean non-nullable
+  public static boolean[] toBooleanArray(ArrayData arrayData) {
+    return arrayData.toBooleanArray();
+  }
+
+  // Boolean nullable
+  public static Boolean[] toBoxedBooleanArray(ArrayData arrayData) {
+    return (Boolean[]) arrayData.toArray(BooleanType,
+        ClassTag$.MODULE$.apply(java.lang.Boolean.class));
+  }
+
+  // byte
+  // byte non-nullable
+  public static byte[] toByteArray(ArrayData arrayData) {
+    return arrayData.toByteArray();
+  }
+
+  // Byte nullable
+  public static Byte[] toBoxedByteArray(ArrayData arrayData) {
+    return (Byte[]) arrayData.toArray(ByteType, 
ClassTag$.MODULE$.apply(java.lang.Byte.class));
+  }
+
+  // short
+  // short non-nullable
+  public static short[] toShortArray(ArrayData arrayData) {
+    return arrayData.toShortArray();
+  }
+
+  // Short nullable
+  public static Short[] toBoxedShortArray(ArrayData arrayData) {
+    return (Short[]) arrayData.toArray(ShortType, 
ClassTag$.MODULE$.apply(java.lang.Short.class));
+  }
+
+  // int
+  // int non-nullable
+  public static int[] toIntegerArray(ArrayData arrayData) {
+    return arrayData.toIntArray();
+  }
+
+  // Integer nullable
+  public static Integer[] toBoxedIntegerArray(ArrayData arrayData) {
+    return (Integer[]) arrayData.toArray(IntegerType,
+        ClassTag$.MODULE$.apply(java.lang.Integer.class));
+  }
+
+  // long
+  // long non-nullable
+  public static long[] toLongArray(ArrayData arrayData) {
+    return arrayData.toLongArray();
+  }
+
+  // Long nullable
+  public static Long[] toBoxedLongArray(ArrayData arrayData) {
+    return (Long[]) arrayData.toArray(LongType, 
ClassTag$.MODULE$.apply(java.lang.Long.class));
+  }
+
+  // float
+  // float non-nullable
+  public static float[] toFloatArray(ArrayData arrayData) {
+    return arrayData.toFloatArray();
+  }
+
+  // Float nullable
+  public static Float[] toBoxedFloatArray(ArrayData arrayData) {
+    return (Float[]) arrayData.toArray(FloatType, 
ClassTag$.MODULE$.apply(java.lang.Float.class));
+  }
+
+  // double
+  // double non-nullable
+  public static double[] toDoubleArray(ArrayData arrayData) {
+    return arrayData.toDoubleArray();
+  }
+
+  // Double nullable
+  public static Double[] toBoxedDoubleArray(ArrayData arrayData) {
+    return (Double[]) arrayData.toArray(DoubleType,
+        ClassTag$.MODULE$.apply(java.lang.Double.class));
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala
new file mode 100644
index 000000000000..861d7ff4024a
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToJavaArray.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.lang.reflect.{Array => JArray}
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.errors.QueryErrorsBase
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+/**
+ * This expression converts data of `ArrayData` to an array of java type.
+ *
+ * NOTE: When the data type of expression is `ArrayType`, and the expression 
is foldable,
+ * the `ConstantFolding` can do constant folding optimization automatically,
+ * (avoiding frequent calls to `ArrayData.to{XXX}Array()`).
+ */
+case class ToJavaArray(array: Expression)
+  extends UnaryExpression
+  with NullIntolerant
+  with RuntimeReplaceable
+  with QueryErrorsBase {
+
+  override def checkInputDataTypes(): TypeCheckResult = array.dataType match {
+    case ArrayType(_, _) =>
+      TypeCheckResult.TypeCheckSuccess
+    case _ =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(0),
+          "requiredType" -> toSQLType(ArrayType),
+          "inputSql" -> toSQLExpr(array),
+          "inputType" -> toSQLType(array.dataType))
+      )
+  }
+
+  override def foldable: Boolean = array.foldable
+
+  override def child: Expression = array
+  override def prettyName: String = "to_java_array"
+
+  private def resultArrayElementNullable: Boolean =
+    array.dataType.asInstanceOf[ArrayType].containsNull
+  private def isPrimitiveType: Boolean = 
CodeGenerator.isPrimitiveType(elementType)
+  private def canPerformFast: Boolean = isPrimitiveType && 
!resultArrayElementNullable
+
+  @transient lazy val elementType: DataType =
+    array.dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val elementObjectType = ObjectType(classOf[DataType])
+  @transient private lazy val elementCls: Class[_] = {
+    if (canPerformFast) {
+      CodeGenerator.javaClass(elementType)
+    } else if (isPrimitiveType) {
+      Utils.classForName(s"java.lang.${CodeGenerator.boxedType(elementType)}")
+    } else {
+      classOf[Object]
+    }
+  }
+  @transient private lazy val returnCls = JArray.newInstance(elementCls, 
0).getClass
+
+  override def dataType: DataType = ObjectType(returnCls)
+
+  override def replacement: Expression = {
+    if (isPrimitiveType) {
+      val funcNamePrefix = if (resultArrayElementNullable) "toBoxed" else "to"
+      val funcName = 
s"$funcNamePrefix${CodeGenerator.boxedType(elementType)}Array"
+      StaticInvoke(
+        classOf[ToJavaArrayUtils],
+        dataType,
+        funcName,
+        Seq(array),
+        Seq(array.dataType))
+    } else {
+      Invoke(
+        array,
+        "toObjectArray",
+        dataType,
+        Seq(Literal(elementType, elementObjectType)),
+        Seq(elementObjectType))
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): 
Expression =
+    copy(array = newChild)
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 0d563530bcbc..10e64626d1a1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1603,12 +1603,7 @@ case class ArrayBinarySearch(array: Expression, value: 
Expression)
 
   @transient private lazy val elementType: DataType =
     array.dataType.asInstanceOf[ArrayType].elementType
-  @transient private lazy val resultArrayElementNullable: Boolean =
-    array.dataType.asInstanceOf[ArrayType].containsNull
-
   @transient private lazy val isPrimitiveType: Boolean = 
CodeGenerator.isPrimitiveType(elementType)
-  @transient private lazy val canPerformFastBinarySearch: Boolean = 
isPrimitiveType &&
-    elementType != BooleanType && !resultArrayElementNullable
 
   @transient private lazy val comp: Comparator[Any] = new Comparator[Any] with 
Serializable {
     private val ordering = array.dataType match {
@@ -1619,39 +1614,28 @@ case class ArrayBinarySearch(array: Expression, value: 
Expression)
     override def compare(o1: Any, o2: Any): Int =
       (o1, o2) match {
         case (null, null) => 0
-        case (null, _) => 1
-        case (_, null) => -1
+        case (null, _) => -1
+        case (_, null) => 1
         case _ => ordering.compare(o1, o2)
       }
   }
 
-  @transient private lazy val elementObjectType = ObjectType(classOf[DataType])
-  @transient private lazy val  comparatorObjectType = 
ObjectType(classOf[Comparator[Object]])
-  override def replacement: Expression =
-    if (canPerformFastBinarySearch) {
-      StaticInvoke(
-        classOf[ArrayExpressionUtils],
-        IntegerType,
-        "binarySearch",
-        Seq(array, value),
-        inputTypes)
-    } else if (isPrimitiveType) {
-      StaticInvoke(
-        classOf[ArrayExpressionUtils],
-        IntegerType,
-        "binarySearchNullSafe",
-        Seq(array, value),
-        inputTypes)
+  @transient private lazy val comparatorObjectType = 
ObjectType(classOf[Comparator[Object]])
+
+  override def replacement: Expression = {
+    val toJavaArray = ToJavaArray(array)
+    val (arguments, inputTypes) = if (isPrimitiveType) {
+      (Seq(toJavaArray, value), Seq(toJavaArray.dataType, value.dataType))
     } else {
-      StaticInvoke(
-        classOf[ArrayExpressionUtils],
-        IntegerType,
-        "binarySearch",
-        Seq(Literal(elementType, elementObjectType),
-          Literal(comp, comparatorObjectType),
-          array,
-          value),
-        elementObjectType +: comparatorObjectType +: inputTypes)
+      (Seq(toJavaArray, value, Literal(comp, comparatorObjectType)),
+        Seq(toJavaArray.dataType, value.dataType, comparatorObjectType))
+    }
+    StaticInvoke(
+      classOf[ArrayExpressionUtils],
+      IntegerType,
+      "binarySearch",
+      arguments,
+      inputTypes)
   }
 
   override def prettyName: String = "array_binary_search"
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 9af63a754124..7c198f05cf49 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -86,11 +86,29 @@ trait InvokeLike extends Expression with NonSQLExpression 
with ImplicitCastInput
 
   // Returns true if we can trust all values of the given DataType can be 
serialized.
   private def trustedSerializable(dt: DataType): Boolean = {
-    // Right now we conservatively block all ObjectType (Java objects) 
regardless of
-    // serializability, because the type-level info with java.io.Serializable 
and
-    // java.io.Externalizable marker interfaces are not strong guarantees.
+    // Right now we conservatively block all ObjectType (Java objects) except 
for
+    // it's `cls` equal to `Array[JavaBoxedPrimitive]` & `JavaBoxedPrimitive`
+    // regardless of serializability, because the type-level info with 
java.io.Serializable
+    // and java.io.Externalizable marker interfaces are not strong guarantees.
     // This restriction can be relaxed in the future to expose more 
optimizations.
-    !dt.existsRecursively(_.isInstanceOf[ObjectType])
+    !dt.existsRecursively {
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Boolean]] => false
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Byte]] => false
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Short]] => false
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Integer]] => false
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Long]] => false
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Float]] => false
+      case ObjectType(cls) if cls == classOf[Array[java.lang.Double]] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Boolean] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Byte] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Short] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Integer] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Long] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Float] => false
+      case ObjectType(cls) if cls == classOf[java.lang.Double] => false
+      case ObjectType(_) => true
+      case _ => false
+    }
   }
 
   /**
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 55148978fa00..1907ec7c23aa 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -137,6 +137,8 @@ class CollectionExpressionsSuite
 
   test("ArrayBinarySearch") {
     // primitive type: boolean、byte、short、int、long、float、double
+    // boolean
+    // boolean foldable
     val a0_0 = Literal.create(Seq(false, true),
       ArrayType(BooleanType, containsNull = false))
     checkEvaluation(ArrayBinarySearch(a0_0, Literal(true)), 1)
@@ -144,7 +146,23 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayBinarySearch(a0_1, Literal(false)), 1)
     val a0_2 = Literal.create(Seq(null, false, true), ArrayType(BooleanType))
     checkEvaluation(ArrayBinarySearch(a0_2, Literal(null, BooleanType)), null)
-
+    val a0_3 = CreateArray(Seq(Literal(false), Literal(true)))
+    checkEvaluation(ArrayBinarySearch(a0_3, Literal(true)), 1)
+    val a0_4 = CreateArray(Seq(Literal(null, BooleanType), Literal(false), 
Literal(true)))
+    checkEvaluation(ArrayBinarySearch(a0_4, Literal(false)), 1)
+    val a0_5 = CreateArray(Seq(Literal(null, BooleanType), Literal(false), 
Literal(true)))
+    checkEvaluation(ArrayBinarySearch(a0_5, Literal(null, BooleanType)), null)
+    // boolean non-foldable
+    val a0_6 = NonFoldableLiteral.create(Seq(false, true),
+      ArrayType(BooleanType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a0_6, Literal(true)), 1)
+    val a0_7 = NonFoldableLiteral.create(Seq(null, false, true), 
ArrayType(BooleanType))
+    checkEvaluation(ArrayBinarySearch(a0_7, Literal(false)), 1)
+    val a0_8 = NonFoldableLiteral.create(Seq(null, false, true), 
ArrayType(BooleanType))
+    checkEvaluation(ArrayBinarySearch(a0_8, Literal(null, BooleanType)), null)
+
+    // byte
+    // byte foldable
     val a1_0 = Literal.create(Seq(1.toByte, 2.toByte, 3.toByte),
       ArrayType(ByteType, containsNull = false))
     checkEvaluation(ArrayBinarySearch(a1_0, Literal(3.toByte)), 2)
@@ -155,18 +173,70 @@ class CollectionExpressionsSuite
     val a1_3 = Literal.create(Seq(1.toByte, 3.toByte, 4.toByte),
       ArrayType(ByteType, containsNull = false))
     checkEvaluation(ArrayBinarySearch(a1_3, Literal(2.toByte, ByteType)), -2)
+    val a1_4 = CreateArray(Seq(Literal(1.toByte), Literal(2.toByte), 
Literal(3.toByte)))
+    checkEvaluation(ArrayBinarySearch(a1_4, Literal(3.toByte)), 2)
+    val a1_5 = CreateArray(Seq(Literal(null, ByteType),
+      Literal(1.toByte), Literal(2.toByte), Literal(3.toByte)))
+    checkEvaluation(ArrayBinarySearch(a1_5, Literal(1.toByte)), 1)
+    val a1_6 = CreateArray(Seq(Literal(null, ByteType),
+      Literal(1.toByte), Literal(2.toByte), Literal(3.toByte)))
+    checkEvaluation(ArrayBinarySearch(a1_6, Literal(null, ByteType)), null)
+    val a1_7 = CreateArray(Seq(Literal(1.toByte), Literal(3.toByte), 
Literal(4.toByte)))
+    checkEvaluation(ArrayBinarySearch(a1_7, Literal(2.toByte, ByteType)), -2)
+    // byte non-foldable
+    val a1_8 = NonFoldableLiteral.create(Seq(1.toByte, 2.toByte, 3.toByte),
+      ArrayType(ByteType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a1_8, Literal(3.toByte)), 2)
+    val a1_9 = NonFoldableLiteral.create(Seq(null, 1.toByte, 2.toByte, 
3.toByte),
+      ArrayType(ByteType))
+    checkEvaluation(ArrayBinarySearch(a1_9, Literal(1.toByte)), 1)
+    val a1_10 = NonFoldableLiteral.create(Seq(null, 1.toByte, 2.toByte, 
3.toByte),
+      ArrayType(ByteType))
+    checkEvaluation(ArrayBinarySearch(a1_10, Literal(null, ByteType)), null)
+    val a1_11 = NonFoldableLiteral.create(Seq(1.toByte, 3.toByte, 4.toByte),
+      ArrayType(ByteType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a1_11, Literal(2.toByte, ByteType)), -2)
 
+    // short
+    // short foldable
     val a2_0 = Literal.create(Seq(1.toShort, 2.toShort, 3.toShort),
       ArrayType(ShortType, containsNull = false))
     checkEvaluation(ArrayBinarySearch(a2_0, Literal(1.toShort)), 0)
-    val a2_1 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), 
ArrayType(ShortType))
+    val a2_1 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort),
+      ArrayType(ShortType))
     checkEvaluation(ArrayBinarySearch(a2_1, Literal(2.toShort)), 2)
-    val a2_2 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), 
ArrayType(ShortType))
+    val a2_2 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort),
+      ArrayType(ShortType))
     checkEvaluation(ArrayBinarySearch(a2_2, Literal(null, ShortType)), null)
     val a2_3 = Literal.create(Seq(1.toShort, 3.toShort, 4.toShort),
       ArrayType(ShortType, containsNull = false))
     checkEvaluation(ArrayBinarySearch(a2_3, Literal(2.toShort, ShortType)), -2)
+    val a2_4 = CreateArray(Seq(Literal(1.toShort), Literal(2.toShort), 
Literal(3.toShort)))
+    checkEvaluation(ArrayBinarySearch(a2_4, Literal(1.toShort)), 0)
+    val a2_5 = CreateArray(Seq(Literal(null, ShortType),
+      Literal(1.toShort), Literal(2.toShort), Literal(3.toShort)))
+    checkEvaluation(ArrayBinarySearch(a2_5, Literal(2.toShort)), 2)
+    val a2_6 = CreateArray(Seq(Literal(null, ShortType),
+      Literal(1.toShort), Literal(2.toShort), Literal(3.toShort)))
+    checkEvaluation(ArrayBinarySearch(a2_6, Literal(null, ShortType)), null)
+    val a2_7 = CreateArray(Seq(Literal(1.toShort), Literal(3.toShort), 
Literal(4.toShort)))
+    checkEvaluation(ArrayBinarySearch(a2_7, Literal(2.toShort, ShortType)), -2)
+    // short non-foldable
+    val a2_8 = NonFoldableLiteral.create(Seq(1.toShort, 2.toShort, 3.toShort),
+      ArrayType(ShortType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a2_8, Literal(1.toShort)), 0)
+    val a2_9 = NonFoldableLiteral.create(Seq(null, 1.toShort, 2.toShort, 
3.toShort),
+      ArrayType(ShortType))
+    checkEvaluation(ArrayBinarySearch(a2_9, Literal(2.toShort)), 2)
+    val a2_10 = NonFoldableLiteral.create(Seq(null, 1.toShort, 2.toShort, 
3.toShort),
+      ArrayType(ShortType))
+    checkEvaluation(ArrayBinarySearch(a2_10, Literal(null, ShortType)), null)
+    val a2_11 = NonFoldableLiteral.create(Seq(1.toShort, 3.toShort, 4.toShort),
+      ArrayType(ShortType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a2_11, Literal(2.toShort, ShortType)), 
-2)
 
+    // int
+    // int foldable
     val a3_0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a3_0, Literal(2)), 1)
     val a3_1 = Literal.create(Seq(null, 1, 2, 3), ArrayType(IntegerType))
@@ -175,7 +245,28 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayBinarySearch(a3_2, Literal(null, IntegerType)), null)
     val a3_3 = Literal.create(Seq(1, 3, 4), ArrayType(IntegerType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a3_3, Literal(2, IntegerType)), -2)
-
+    val a3_4 = CreateArray(Seq(Literal(1), Literal(2), Literal(3)))
+    checkEvaluation(ArrayBinarySearch(a3_4, Literal(2)), 1)
+    val a3_5 = CreateArray(Seq(Literal(null, IntegerType), Literal(1), 
Literal(2), Literal(3)))
+    checkEvaluation(ArrayBinarySearch(a3_5, Literal(2)), 2)
+    val a3_6 = CreateArray(Seq(Literal(null, IntegerType), Literal(1), 
Literal(2), Literal(3)))
+    checkEvaluation(ArrayBinarySearch(a3_6, Literal(null, IntegerType)), null)
+    val a3_7 = CreateArray(Seq(Literal(1), Literal(3), Literal(4)))
+    checkEvaluation(ArrayBinarySearch(a3_7, Literal(2, IntegerType)), -2)
+    // int non-foldable
+    val a3_8 = NonFoldableLiteral.create(Seq(1, 2, 3),
+      ArrayType(IntegerType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a3_8, Literal(2)), 1)
+    val a3_9 = NonFoldableLiteral.create(Seq(null, 1, 2, 3), 
ArrayType(IntegerType))
+    checkEvaluation(ArrayBinarySearch(a3_9, Literal(2)), 2)
+    val a3_10 = NonFoldableLiteral.create(Seq(null, 1, 2, 3), 
ArrayType(IntegerType))
+    checkEvaluation(ArrayBinarySearch(a3_10, Literal(null, IntegerType)), null)
+    val a3_11 = NonFoldableLiteral.create(Seq(1, 3, 4),
+      ArrayType(IntegerType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a3_11, Literal(2, IntegerType)), -2)
+
+    // long
+    // long foldable
     val a4_0 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a4_0, Literal(2L)), 1)
     val a4_1 = Literal.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType))
@@ -184,7 +275,30 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayBinarySearch(a4_2, Literal(null, LongType)), null)
     val a4_3 = Literal.create(Seq(1L, 3L, 4L), ArrayType(LongType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a4_3, Literal(2L, LongType)), -2)
-
+    val a4_4 = CreateArray(Seq(Literal(1L), Literal(2L), Literal(3L)))
+    checkEvaluation(ArrayBinarySearch(a4_4, Literal(2L)), 1)
+    val a4_5 = CreateArray(Seq(Literal(null, LongType),
+      Literal(1L), Literal(2L), Literal(3L)))
+    checkEvaluation(ArrayBinarySearch(a4_5, Literal(2L)), 2)
+    val a4_6 = CreateArray(Seq(Literal(null, LongType),
+      Literal(1L), Literal(2L), Literal(3L)))
+    checkEvaluation(ArrayBinarySearch(a4_6, Literal(null, LongType)), null)
+    val a4_7 = CreateArray(Seq(Literal(1L), Literal(3L), Literal(4L)))
+    checkEvaluation(ArrayBinarySearch(a4_7, Literal(2L, LongType)), -2)
+    // long non-foldable
+    val a4_8 = NonFoldableLiteral.create(Seq(1L, 2L, 3L),
+      ArrayType(LongType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a4_8, Literal(2L)), 1)
+    val a4_9 = NonFoldableLiteral.create(Seq(null, 1L, 2L, 3L), 
ArrayType(LongType))
+    checkEvaluation(ArrayBinarySearch(a4_9, Literal(2L)), 2)
+    val a4_10 = NonFoldableLiteral.create(Seq(null, 1L, 2L, 3L), 
ArrayType(LongType))
+    checkEvaluation(ArrayBinarySearch(a4_10, Literal(null, LongType)), null)
+    val a4_11 = NonFoldableLiteral.create(Seq(1L, 3L, 4L),
+      ArrayType(LongType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a4_11, Literal(2L, LongType)), -2)
+
+    // float
+    // float foldable
     val a5_0 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a5_0, Literal(3.0F)), 2)
     val a5_1 = Literal.create(Seq(null, 1.0F, 2.0F, 3.0F), 
ArrayType(FloatType))
@@ -193,7 +307,30 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayBinarySearch(a5_2, Literal(null, FloatType)), null)
     val a5_3 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a5_3, Literal(1.1F, FloatType)), -2)
-
+    val a5_4 = CreateArray(Seq(Literal(1.0F), Literal(2.0F), Literal(3.0F)))
+    checkEvaluation(ArrayBinarySearch(a5_4, Literal(3.0F)), 2)
+    val a5_5 = CreateArray(Seq(Literal(null, FloatType),
+      Literal(1.0F), Literal(2.0F), Literal(3.0F)))
+    checkEvaluation(ArrayBinarySearch(a5_5, Literal(1.0F)), 1)
+    val a5_6 = CreateArray(Seq(Literal(null, FloatType),
+      Literal(1.0F), Literal(2.0F), Literal(3.0F)))
+    checkEvaluation(ArrayBinarySearch(a5_6, Literal(null, FloatType)), null)
+    val a5_7 = CreateArray(Seq(Literal(1.0F), Literal(2.0F), Literal(3.0F)))
+    checkEvaluation(ArrayBinarySearch(a5_7, Literal(1.1F, FloatType)), -2)
+    // float non-foldable
+    val a5_8 = NonFoldableLiteral.create(Seq(1.0F, 2.0F, 3.0F),
+      ArrayType(FloatType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a5_8, Literal(3.0F)), 2)
+    val a5_9 = NonFoldableLiteral.create(Seq(null, 1.0F, 2.0F, 3.0F), 
ArrayType(FloatType))
+    checkEvaluation(ArrayBinarySearch(a5_9, Literal(1.0F)), 1)
+    val a5_10 = NonFoldableLiteral.create(Seq(null, 1.0F, 2.0F, 3.0F), 
ArrayType(FloatType))
+    checkEvaluation(ArrayBinarySearch(a5_10, Literal(null, FloatType)), null)
+    val a5_11 = NonFoldableLiteral.create(Seq(1.0F, 2.0F, 3.0F),
+      ArrayType(FloatType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a5_11, Literal(1.1F, FloatType)), -2)
+
+    // double
+    // double foldable
     val a6_0 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a6_0, Literal(1.0d)), 0)
     val a6_1 = Literal.create(Seq(null, 1.0d, 2.0d, 3.0d), 
ArrayType(DoubleType))
@@ -202,8 +339,30 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayBinarySearch(a6_2, Literal(null, DoubleType)), null)
     val a6_3 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a6_3, Literal(1.1d, DoubleType)), -2)
+    val a6_4 = CreateArray(Seq(Literal(1.0d), Literal(2.0d), Literal(3.0d)))
+    checkEvaluation(ArrayBinarySearch(a6_4, Literal(1.0d)), 0)
+    val a6_5 = CreateArray(Seq(Literal(null, DoubleType),
+      Literal(1.0d), Literal(2.0d), Literal(3.0d)))
+    checkEvaluation(ArrayBinarySearch(a6_5, Literal(1.0d)), 1)
+    val a6_6 = CreateArray(Seq(Literal(null, DoubleType),
+      Literal(1.0d), Literal(2.0d), Literal(3.0d)))
+    checkEvaluation(ArrayBinarySearch(a6_6, Literal(null, DoubleType)), null)
+    val a6_7 = CreateArray(Seq(Literal(1.0d), Literal(2.0d), Literal(3.0d)))
+    checkEvaluation(ArrayBinarySearch(a6_7, Literal(1.1d, DoubleType)), -2)
+    // double non-foldable
+    val a6_8 = NonFoldableLiteral.create(Seq(1.0d, 2.0d, 3.0d),
+      ArrayType(DoubleType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a6_8, Literal(1.0d)), 0)
+    val a6_9 = NonFoldableLiteral.create(Seq(null, 1.0d, 2.0d, 3.0d), 
ArrayType(DoubleType))
+    checkEvaluation(ArrayBinarySearch(a6_9, Literal(1.0d)), 1)
+    val a6_10 = NonFoldableLiteral.create(Seq(null, 1.0d, 2.0d, 3.0d), 
ArrayType(DoubleType))
+    checkEvaluation(ArrayBinarySearch(a6_10, Literal(null, DoubleType)), null)
+    val a6_11 = NonFoldableLiteral.create(Seq(1.0d, 2.0d, 3.0d),
+      ArrayType(DoubleType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a6_11, Literal(1.1d, DoubleType)), -2)
 
     // string
+    // string foldable
     val a7_0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a7_0, Literal("a")), 0)
     val a7_1 = Literal.create(Seq(null, "a", "b", "c"), ArrayType(StringType))
@@ -212,6 +371,27 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayBinarySearch(a7_2, Literal(null, StringType)), null)
     val a7_3 = Literal.create(Seq("a", "c", "d"), ArrayType(StringType, 
containsNull = false))
     checkEvaluation(ArrayBinarySearch(a7_3, 
Literal(UTF8String.fromString("b"), StringType)), -2)
+    val a7_4 = CreateArray(Seq(Literal("a"), Literal("b"), Literal("c")))
+    checkEvaluation(ArrayBinarySearch(a7_4, Literal("a")), 0)
+    val a7_5 = CreateArray(Seq(Literal(null, StringType),
+      Literal("a"), Literal("b"), Literal("c")))
+    checkEvaluation(ArrayBinarySearch(a7_5, Literal("c")), 3)
+    val a7_6 = CreateArray(Seq(Literal(null, StringType),
+      Literal("a"), Literal("b"), Literal("c")))
+    checkEvaluation(ArrayBinarySearch(a7_6, Literal(null, StringType)), null)
+    val a7_7 = CreateArray(Seq(Literal("a"), Literal("c"), Literal("d")))
+    checkEvaluation(ArrayBinarySearch(a7_7, 
Literal(UTF8String.fromString("b"), StringType)), -2)
+    // string non-foldable
+    val a7_8 = NonFoldableLiteral.create(Seq("a", "b", "c"),
+      ArrayType(StringType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a7_8, Literal("a")), 0)
+    val a7_9 = NonFoldableLiteral.create(Seq(null, "a", "b", "c"), 
ArrayType(StringType))
+    checkEvaluation(ArrayBinarySearch(a7_9, Literal("c")), 3)
+    val a7_10 = NonFoldableLiteral.create(Seq(null, "a", "b", "c"), 
ArrayType(StringType))
+    checkEvaluation(ArrayBinarySearch(a7_10, Literal(null, StringType)), null)
+    val a7_11 = NonFoldableLiteral.create(Seq("a", "c", "d"),
+      ArrayType(StringType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a7_11, 
Literal(UTF8String.fromString("b"), StringType)), -2)
   }
 
   test("MapEntries") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to