This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 32b054c40602 [SPARK-49203][SQL] Add expression for 
`java.util.Arrays.binarySearch`
32b054c40602 is described below

commit 32b054c40602c7355176903fa32224774f0c1bec
Author: panbingkun <[email protected]>
AuthorDate: Tue Sep 3 14:47:47 2024 +0800

    [SPARK-49203][SQL] Add expression for `java.util.Arrays.binarySearch`
    
    ### What changes were proposed in this pull request?
    The pr aims to add an expression `array_binary_search` for 
`java.util.Arrays.binarySearch`.
    
    ### Why are the changes needed?
    We can use it to implement `histogram plot` in the client side (no longer 
need to depend on mllib's `Bucketizer`).
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Add new UT.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #47741 from panbingkun/SPARK-49203.
    
    Authored-by: panbingkun <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../catalyst/expressions/ArrayExpressionUtils.java | 176 +++++++++++++++++++++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../expressions/collectionOperations.scala         | 136 ++++++++++++++++
 .../expressions/CollectionExpressionsSuite.scala   |  79 +++++++++
 4 files changed, 392 insertions(+)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
new file mode 100644
index 000000000000..ff6525acbe53
--- /dev/null
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.util.Arrays;
+import java.util.Comparator;
+
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil;
+import org.apache.spark.sql.types.ByteType$;
+import org.apache.spark.sql.types.BooleanType$;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DoubleType$;
+import org.apache.spark.sql.types.FloatType$;
+import org.apache.spark.sql.types.IntegerType$;
+import org.apache.spark.sql.types.LongType$;
+import org.apache.spark.sql.types.ShortType$;
+
+public class ArrayExpressionUtils {
+
+  private static final Comparator<Object> booleanComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    boolean c1 = (Boolean) o1, c2 = (Boolean) o2;
+    return c1 == c2 ? 0 : (c1 ? 1 : -1);
+  };
+
+  private static final Comparator<Object> byteComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    byte c1 = (Byte) o1, c2 = (Byte) o2;
+    return Byte.compare(c1, c2);
+  };
+
+  private static final Comparator<Object> shortComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    short c1 = (Short) o1, c2 = (Short) o2;
+    return Short.compare(c1, c2);
+  };
+
+  private static final Comparator<Object> integerComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    int c1 = (Integer) o1, c2 = (Integer) o2;
+    return Integer.compare(c1, c2);
+  };
+
+  private static final Comparator<Object> longComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    long c1 = (Long) o1, c2 = (Long) o2;
+    return Long.compare(c1, c2);
+  };
+
+  private static final Comparator<Object> floatComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    float c1 = (Float) o1, c2 = (Float) o2;
+    return SQLOrderingUtil.compareFloats(c1, c2);
+  };
+
+  private static final Comparator<Object> doubleComp = (o1, o2) -> {
+    if (o1 == null && o2 == null) {
+      return 0;
+    } else if (o1 == null) {
+      return -1;
+    } else if (o2 == null) {
+      return 1;
+    }
+    double c1 = (Double) o1, c2 = (Double) o2;
+    return SQLOrderingUtil.compareDoubles(c1, c2);
+  };
+
+  public static int binarySearchNullSafe(ArrayData data, Boolean value) {
+    return Arrays.binarySearch(data.toObjectArray(BooleanType$.MODULE$), 
value, booleanComp);
+  }
+
+  public static int binarySearch(ArrayData data, byte value) {
+    return Arrays.binarySearch(data.toByteArray(), value);
+  }
+
+  public static int binarySearchNullSafe(ArrayData data, Byte value) {
+    return Arrays.binarySearch(data.toObjectArray(ByteType$.MODULE$), value, 
byteComp);
+  }
+
+  public static int binarySearch(ArrayData data, short value) {
+    return Arrays.binarySearch(data.toShortArray(), value);
+  }
+
+  public static int binarySearchNullSafe(ArrayData data, Short value) {
+    return Arrays.binarySearch(data.toObjectArray(ShortType$.MODULE$), value, 
shortComp);
+  }
+
+  public static int binarySearch(ArrayData data, int value) {
+    return Arrays.binarySearch(data.toIntArray(), value);
+  }
+
+  public static int binarySearchNullSafe(ArrayData data, Integer value) {
+    return Arrays.binarySearch(data.toObjectArray(IntegerType$.MODULE$), 
value, integerComp);
+  }
+
+  public static int binarySearch(ArrayData data, long value) {
+    return Arrays.binarySearch(data.toLongArray(), value);
+  }
+
+  public static int binarySearchNullSafe(ArrayData data, Long value) {
+    return Arrays.binarySearch(data.toObjectArray(LongType$.MODULE$), value, 
longComp);
+  }
+
+  public static int binarySearch(ArrayData data, float value) {
+    return Arrays.binarySearch(data.toFloatArray(), value);
+  }
+
+  public static int binarySearchNullSafe(ArrayData data, Float value) {
+    return Arrays.binarySearch(data.toObjectArray(FloatType$.MODULE$), value, 
floatComp);
+  }
+
+  public static int binarySearch(ArrayData data, double value) {
+    return Arrays.binarySearch(data.toDoubleArray(), value);
+  }
+
+  public static int binarySearchNullSafe(ArrayData data, Double value) {
+    return Arrays.binarySearch(data.toObjectArray(DoubleType$.MODULE$), value, 
doubleComp);
+  }
+
+  public static int binarySearch(
+    DataType elementType, Comparator<Object> comp, ArrayData data, Object 
value) {
+    Object[] array = data.toObjectArray(elementType);
+    return Arrays.binarySearch(array, value, comp);
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e851d5f2b91c..dfe1bd12bb7f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -919,6 +919,7 @@ object FunctionRegistry {
   registerInternalExpression[EWM]("ewm")
   registerInternalExpression[NullIndex]("null_index")
   registerInternalExpression[CastTimestampNTZToLong]("timestamp_ntz_to_long")
+  registerInternalExpression[ArrayBinarySearch]("array_binary_search")
 
   private def makeExprInfoForVirtualOperator(name: String, usage: String): 
ExpressionInfo = {
     new ExpressionInfo(
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 516f521bc964..375a2bde5923 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -29,6 +29,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
 import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, 
TreePattern}
 import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, 
PhysicalIntegralType}
@@ -1518,6 +1519,141 @@ case class ArrayContains(left: Expression, right: 
Expression)
     copy(left = newLeft, right = newRight)
 }
 
+/**
+ * Searches the specified array for the specified object using the binary 
search algorithm.
+ *
+ * NOTE: The input array must be in ascending order before calling this 
method; if the array is
+ * not sorted, the results are undefined.
+ *
+ * This expression is dedicated only for PySpark and Spark-ML.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array, value) - Return index (0-based) of the search value, 
" +
+    "if it is contained in the array; otherwise, (-<insertion point> - 1).",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 2);
+       1
+      > SELECT _FUNC_(array(null, 1, 2, 3), 2);
+       2
+      > SELECT _FUNC_(array(1.0F, 2.0F, 3.0F), 1.1F);
+       -2
+  """,
+  group = "array_funcs",
+  since = "4.0.0")
+case class ArrayBinarySearch(array: Expression, value: Expression)
+  extends BinaryExpression
+  with ImplicitCastInputTypes
+  with NullIntolerant
+  with RuntimeReplaceable
+  with QueryErrorsBase {
+
+  override def left: Expression = array
+  override def right: Expression = value
+  override def dataType: DataType = IntegerType
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (_, NullType) => Seq.empty
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) =>
+        DataTypeMismatch(
+          errorSubClass = "NULL_TYPE",
+          Map("functionName" -> toSQLId(prettyName)))
+      case (t, _) if !ArrayType.acceptsType(t) =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(0),
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType))
+        )
+      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+        TypeUtils.checkForOrderingExpr(e2, prettyName)
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "dataType" -> toSQLType(ArrayType),
+            "leftType" -> toSQLType(left.dataType),
+            "rightType" -> toSQLType(right.dataType)
+          )
+        )
+    }
+  }
+
+  @transient private lazy val elementType: DataType =
+    array.dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val resultArrayElementNullable: Boolean =
+    array.dataType.asInstanceOf[ArrayType].containsNull
+
+  @transient private lazy val isPrimitiveType: Boolean = 
CodeGenerator.isPrimitiveType(elementType)
+  @transient private lazy val canPerformFastBinarySearch: Boolean = 
isPrimitiveType &&
+    elementType != BooleanType && !resultArrayElementNullable
+
+  @transient private lazy val comp: Comparator[Any] = new Comparator[Any] with 
Serializable {
+    private val ordering = array.dataType match {
+      case _ @ ArrayType(n, _) =>
+        PhysicalDataType.ordering(n)
+    }
+
+    override def compare(o1: Any, o2: Any): Int =
+      (o1, o2) match {
+        case (null, null) => 0
+        case (null, _) => 1
+        case (_, null) => -1
+        case _ => ordering.compare(o1, o2)
+      }
+  }
+
+  @transient private lazy val elementObjectType = ObjectType(classOf[DataType])
+  @transient private lazy val  comparatorObjectType = 
ObjectType(classOf[Comparator[Object]])
+  override def replacement: Expression =
+    if (canPerformFastBinarySearch) {
+      StaticInvoke(
+        classOf[ArrayExpressionUtils],
+        IntegerType,
+        "binarySearch",
+        Seq(array, value),
+        inputTypes)
+    } else if (isPrimitiveType) {
+      StaticInvoke(
+        classOf[ArrayExpressionUtils],
+        IntegerType,
+        "binarySearchNullSafe",
+        Seq(array, value),
+        inputTypes)
+    } else {
+      StaticInvoke(
+        classOf[ArrayExpressionUtils],
+        IntegerType,
+        "binarySearch",
+        Seq(Literal(elementType, elementObjectType),
+          Literal(comp, comparatorObjectType),
+          array,
+          value),
+        elementObjectType +: comparatorObjectType +: inputTypes)
+  }
+
+  override def prettyName: String = "array_binary_search"
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): ArrayBinarySearch =
+    copy(array = newLeft, value = newRight)
+}
+
 trait ArrayPendBase extends RuntimeReplaceable
   with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase 
{
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d14118eb3f1d..c7e995feb5ed 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -135,6 +135,85 @@ class CollectionExpressionsSuite
     checkEvaluation(ArrayContains(MapKeys(m1), Literal("a")), null)
   }
 
+  test("ArrayBinarySearch") {
+    // primitive type: boolean、byte、short、int、long、float、double
+    val a0_0 = Literal.create(Seq(false, true),
+      ArrayType(BooleanType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a0_0, Literal(true)), 1)
+    val a0_1 = Literal.create(Seq(null, false, true), ArrayType(BooleanType))
+    checkEvaluation(ArrayBinarySearch(a0_1, Literal(false)), 1)
+    val a0_2 = Literal.create(Seq(null, false, true), ArrayType(BooleanType))
+    checkEvaluation(ArrayBinarySearch(a0_2, Literal(null, BooleanType)), null)
+
+    val a1_0 = Literal.create(Seq(1.toByte, 2.toByte, 3.toByte),
+      ArrayType(ByteType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a1_0, Literal(3.toByte)), 2)
+    val a1_1 = Literal.create(Seq(null, 1.toByte, 2.toByte, 3.toByte), 
ArrayType(ByteType))
+    checkEvaluation(ArrayBinarySearch(a1_1, Literal(1.toByte)), 1)
+    val a1_2 = Literal.create(Seq(null, 1.toByte, 2.toByte, 3.toByte), 
ArrayType(ByteType))
+    checkEvaluation(ArrayBinarySearch(a1_2, Literal(null, ByteType)), null)
+    val a1_3 = Literal.create(Seq(1.toByte, 3.toByte, 4.toByte),
+      ArrayType(ByteType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a1_3, Literal(2.toByte, ByteType)), -2)
+
+    val a2_0 = Literal.create(Seq(1.toShort, 2.toShort, 3.toShort),
+      ArrayType(ShortType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a2_0, Literal(1.toShort)), 0)
+    val a2_1 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), 
ArrayType(ShortType))
+    checkEvaluation(ArrayBinarySearch(a2_1, Literal(2.toShort)), 2)
+    val a2_2 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort), 
ArrayType(ShortType))
+    checkEvaluation(ArrayBinarySearch(a2_2, Literal(null, ShortType)), null)
+    val a2_3 = Literal.create(Seq(1.toShort, 3.toShort, 4.toShort),
+      ArrayType(ShortType, containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a2_3, Literal(2.toShort, ShortType)), -2)
+
+    val a3_0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a3_0, Literal(2)), 1)
+    val a3_1 = Literal.create(Seq(null, 1, 2, 3), ArrayType(IntegerType))
+    checkEvaluation(ArrayBinarySearch(a3_1, Literal(2)), 2)
+    val a3_2 = Literal.create(Seq(null, 1, 2, 3), ArrayType(IntegerType))
+    checkEvaluation(ArrayBinarySearch(a3_2, Literal(null, IntegerType)), null)
+    val a3_3 = Literal.create(Seq(1, 3, 4), ArrayType(IntegerType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a3_3, Literal(2, IntegerType)), -2)
+
+    val a4_0 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a4_0, Literal(2L)), 1)
+    val a4_1 = Literal.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType))
+    checkEvaluation(ArrayBinarySearch(a4_1, Literal(2L)), 2)
+    val a4_2 = Literal.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType))
+    checkEvaluation(ArrayBinarySearch(a4_2, Literal(null, LongType)), null)
+    val a4_3 = Literal.create(Seq(1L, 3L, 4L), ArrayType(LongType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a4_3, Literal(2L, LongType)), -2)
+
+    val a5_0 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a5_0, Literal(3.0F)), 2)
+    val a5_1 = Literal.create(Seq(null, 1.0F, 2.0F, 3.0F), 
ArrayType(FloatType))
+    checkEvaluation(ArrayBinarySearch(a5_1, Literal(1.0F)), 1)
+    val a5_2 = Literal.create(Seq(null, 1.0F, 2.0F, 3.0F), 
ArrayType(FloatType))
+    checkEvaluation(ArrayBinarySearch(a5_2, Literal(null, FloatType)), null)
+    val a5_3 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a5_3, Literal(1.1F, FloatType)), -2)
+
+    val a6_0 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a6_0, Literal(1.0d)), 0)
+    val a6_1 = Literal.create(Seq(null, 1.0d, 2.0d, 3.0d), 
ArrayType(DoubleType))
+    checkEvaluation(ArrayBinarySearch(a6_1, Literal(1.0d)), 1)
+    val a6_2 = Literal.create(Seq(null, 1.0d, 2.0d, 3.0d), 
ArrayType(DoubleType))
+    checkEvaluation(ArrayBinarySearch(a6_2, Literal(null, DoubleType)), null)
+    val a6_3 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a6_3, Literal(1.1d, DoubleType)), -2)
+
+    // string
+    val a7_0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a7_0, Literal("a")), 0)
+    val a7_1 = Literal.create(Seq(null, "a", "b", "c"), ArrayType(StringType))
+    checkEvaluation(ArrayBinarySearch(a7_1, Literal("c")), 3)
+    val a7_2 = Literal.create(Seq(null, "a", "b", "c"), ArrayType(StringType))
+    checkEvaluation(ArrayBinarySearch(a7_2, Literal(null, StringType)), null)
+    val a7_3 = Literal.create(Seq("a", "c", "d"), ArrayType(StringType, 
containsNull = false))
+    checkEvaluation(ArrayBinarySearch(a7_3, 
Literal(UTF8String.fromString("b"), StringType)), -2)
+  }
+
   test("MapEntries") {
     def r(values: Any*): InternalRow = create_row(values: _*)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to