This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 32b054c40602 [SPARK-49203][SQL] Add expression for
`java.util.Arrays.binarySearch`
32b054c40602 is described below
commit 32b054c40602c7355176903fa32224774f0c1bec
Author: panbingkun <[email protected]>
AuthorDate: Tue Sep 3 14:47:47 2024 +0800
[SPARK-49203][SQL] Add expression for `java.util.Arrays.binarySearch`
### What changes were proposed in this pull request?
The PR aims to add an expression `array_binary_search` for
`java.util.Arrays.binarySearch`.
### Why are the changes needed?
We can use it to implement a `histogram plot` on the client side (no longer
needing to depend on mllib's `Bucketizer`).
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Add new UT.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47741 from panbingkun/SPARK-49203.
Authored-by: panbingkun <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../catalyst/expressions/ArrayExpressionUtils.java | 176 +++++++++++++++++++++
.../sql/catalyst/analysis/FunctionRegistry.scala | 1 +
.../expressions/collectionOperations.scala | 136 ++++++++++++++++
.../expressions/CollectionExpressionsSuite.scala | 79 +++++++++
4 files changed, 392 insertions(+)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
new file mode 100644
index 000000000000..ff6525acbe53
--- /dev/null
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ArrayExpressionUtils.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.util.Arrays;
+import java.util.Comparator;
+
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil;
+import org.apache.spark.sql.types.ByteType$;
+import org.apache.spark.sql.types.BooleanType$;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DoubleType$;
+import org.apache.spark.sql.types.FloatType$;
+import org.apache.spark.sql.types.IntegerType$;
+import org.apache.spark.sql.types.LongType$;
+import org.apache.spark.sql.types.ShortType$;
+
+public class ArrayExpressionUtils {
+
+ private static final Comparator<Object> booleanComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ boolean c1 = (Boolean) o1, c2 = (Boolean) o2;
+ return c1 == c2 ? 0 : (c1 ? 1 : -1);
+ };
+
+ private static final Comparator<Object> byteComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ byte c1 = (Byte) o1, c2 = (Byte) o2;
+ return Byte.compare(c1, c2);
+ };
+
+ private static final Comparator<Object> shortComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ short c1 = (Short) o1, c2 = (Short) o2;
+ return Short.compare(c1, c2);
+ };
+
+ private static final Comparator<Object> integerComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ int c1 = (Integer) o1, c2 = (Integer) o2;
+ return Integer.compare(c1, c2);
+ };
+
+ private static final Comparator<Object> longComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ long c1 = (Long) o1, c2 = (Long) o2;
+ return Long.compare(c1, c2);
+ };
+
+ private static final Comparator<Object> floatComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ float c1 = (Float) o1, c2 = (Float) o2;
+ return SQLOrderingUtil.compareFloats(c1, c2);
+ };
+
+ private static final Comparator<Object> doubleComp = (o1, o2) -> {
+ if (o1 == null && o2 == null) {
+ return 0;
+ } else if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ }
+ double c1 = (Double) o1, c2 = (Double) o2;
+ return SQLOrderingUtil.compareDoubles(c1, c2);
+ };
+
+ public static int binarySearchNullSafe(ArrayData data, Boolean value) {
+ return Arrays.binarySearch(data.toObjectArray(BooleanType$.MODULE$),
value, booleanComp);
+ }
+
+ public static int binarySearch(ArrayData data, byte value) {
+ return Arrays.binarySearch(data.toByteArray(), value);
+ }
+
+ public static int binarySearchNullSafe(ArrayData data, Byte value) {
+ return Arrays.binarySearch(data.toObjectArray(ByteType$.MODULE$), value,
byteComp);
+ }
+
+ public static int binarySearch(ArrayData data, short value) {
+ return Arrays.binarySearch(data.toShortArray(), value);
+ }
+
+ public static int binarySearchNullSafe(ArrayData data, Short value) {
+ return Arrays.binarySearch(data.toObjectArray(ShortType$.MODULE$), value,
shortComp);
+ }
+
+ public static int binarySearch(ArrayData data, int value) {
+ return Arrays.binarySearch(data.toIntArray(), value);
+ }
+
+ public static int binarySearchNullSafe(ArrayData data, Integer value) {
+ return Arrays.binarySearch(data.toObjectArray(IntegerType$.MODULE$),
value, integerComp);
+ }
+
+ public static int binarySearch(ArrayData data, long value) {
+ return Arrays.binarySearch(data.toLongArray(), value);
+ }
+
+ public static int binarySearchNullSafe(ArrayData data, Long value) {
+ return Arrays.binarySearch(data.toObjectArray(LongType$.MODULE$), value,
longComp);
+ }
+
+ public static int binarySearch(ArrayData data, float value) {
+ return Arrays.binarySearch(data.toFloatArray(), value);
+ }
+
+ public static int binarySearchNullSafe(ArrayData data, Float value) {
+ return Arrays.binarySearch(data.toObjectArray(FloatType$.MODULE$), value,
floatComp);
+ }
+
+ public static int binarySearch(ArrayData data, double value) {
+ return Arrays.binarySearch(data.toDoubleArray(), value);
+ }
+
+ public static int binarySearchNullSafe(ArrayData data, Double value) {
+ return Arrays.binarySearch(data.toObjectArray(DoubleType$.MODULE$), value,
doubleComp);
+ }
+
+ public static int binarySearch(
+ DataType elementType, Comparator<Object> comp, ArrayData data, Object
value) {
+ Object[] array = data.toObjectArray(elementType);
+ return Arrays.binarySearch(array, value, comp);
+ }
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e851d5f2b91c..dfe1bd12bb7f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -919,6 +919,7 @@ object FunctionRegistry {
registerInternalExpression[EWM]("ewm")
registerInternalExpression[NullIndex]("null_index")
registerInternalExpression[CastTimestampNTZToLong]("timestamp_ntz_to_long")
+ registerInternalExpression[ArrayBinarySearch]("array_binary_search")
private def makeExprInfoForVirtualOperator(name: String, usage: String):
ExpressionInfo = {
new ExpressionInfo(
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 516f521bc964..375a2bde5923 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -29,6 +29,7 @@ import
org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT,
TreePattern}
import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType,
PhysicalIntegralType}
@@ -1518,6 +1519,141 @@ case class ArrayContains(left: Expression, right:
Expression)
copy(left = newLeft, right = newRight)
}
+/**
+ * Searches the specified array for the specified object using the binary
search algorithm.
+ *
+ * NOTE: The input array must be in ascending order before calling this
method; if the array is
+ * not sorted, the results are undefined.
+ *
+ * This expression is dedicated only for PySpark and Spark-ML.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array, value) - Return index (0-based) of the search value,
" +
+ "if it is contained in the array; otherwise, (-<insertion point> - 1).",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), 2);
+ 1
+ > SELECT _FUNC_(array(null, 1, 2, 3), 2);
+ 2
+ > SELECT _FUNC_(array(1.0F, 2.0F, 3.0F), 1.1F);
+ -2
+ """,
+ group = "array_funcs",
+ since = "4.0.0")
+case class ArrayBinarySearch(array: Expression, value: Expression)
+ extends BinaryExpression
+ with ImplicitCastInputTypes
+ with NullIntolerant
+ with RuntimeReplaceable
+ with QueryErrorsBase {
+
+ override def left: Expression = array
+ override def right: Expression = value
+ override def dataType: DataType = IntegerType
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (_, NullType) => Seq.empty
+ case (ArrayType(e1, hasNull), e2) =>
+ TypeCoercion.findTightestCommonType(e1, e2) match {
+ case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (NullType, _) | (_, NullType) =>
+ DataTypeMismatch(
+ errorSubClass = "NULL_TYPE",
+ Map("functionName" -> toSQLId(prettyName)))
+ case (t, _) if !ArrayType.acceptsType(t) =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(0),
+ "requiredType" -> toSQLType(ArrayType),
+ "inputSql" -> toSQLExpr(left),
+ "inputType" -> toSQLType(left.dataType))
+ )
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+ TypeUtils.checkForOrderingExpr(e2, prettyName)
+ case _ =>
+ DataTypeMismatch(
+ errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+ messageParameters = Map(
+ "functionName" -> toSQLId(prettyName),
+ "dataType" -> toSQLType(ArrayType),
+ "leftType" -> toSQLType(left.dataType),
+ "rightType" -> toSQLType(right.dataType)
+ )
+ )
+ }
+ }
+
+ @transient private lazy val elementType: DataType =
+ array.dataType.asInstanceOf[ArrayType].elementType
+ @transient private lazy val resultArrayElementNullable: Boolean =
+ array.dataType.asInstanceOf[ArrayType].containsNull
+
+ @transient private lazy val isPrimitiveType: Boolean =
CodeGenerator.isPrimitiveType(elementType)
+ @transient private lazy val canPerformFastBinarySearch: Boolean =
isPrimitiveType &&
+ elementType != BooleanType && !resultArrayElementNullable
+
+ @transient private lazy val comp: Comparator[Any] = new Comparator[Any] with
Serializable {
+ private val ordering = array.dataType match {
+ case _ @ ArrayType(n, _) =>
+ PhysicalDataType.ordering(n)
+ }
+
+ override def compare(o1: Any, o2: Any): Int =
+ (o1, o2) match {
+ case (null, null) => 0
+ case (null, _) => 1
+ case (_, null) => -1
+ case _ => ordering.compare(o1, o2)
+ }
+ }
+
+ @transient private lazy val elementObjectType = ObjectType(classOf[DataType])
+ @transient private lazy val comparatorObjectType =
ObjectType(classOf[Comparator[Object]])
+ override def replacement: Expression =
+ if (canPerformFastBinarySearch) {
+ StaticInvoke(
+ classOf[ArrayExpressionUtils],
+ IntegerType,
+ "binarySearch",
+ Seq(array, value),
+ inputTypes)
+ } else if (isPrimitiveType) {
+ StaticInvoke(
+ classOf[ArrayExpressionUtils],
+ IntegerType,
+ "binarySearchNullSafe",
+ Seq(array, value),
+ inputTypes)
+ } else {
+ StaticInvoke(
+ classOf[ArrayExpressionUtils],
+ IntegerType,
+ "binarySearch",
+ Seq(Literal(elementType, elementObjectType),
+ Literal(comp, comparatorObjectType),
+ array,
+ value),
+ elementObjectType +: comparatorObjectType +: inputTypes)
+ }
+
+ override def prettyName: String = "array_binary_search"
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression, newRight: Expression): ArrayBinarySearch =
+ copy(array = newLeft, value = newRight)
+}
+
trait ArrayPendBase extends RuntimeReplaceable
with ImplicitCastInputTypes with BinaryLike[Expression] with QueryErrorsBase
{
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d14118eb3f1d..c7e995feb5ed 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -135,6 +135,85 @@ class CollectionExpressionsSuite
checkEvaluation(ArrayContains(MapKeys(m1), Literal("a")), null)
}
+ test("ArrayBinarySearch") {
+ // primitive type: boolean、byte、short、int、long、float、double
+ val a0_0 = Literal.create(Seq(false, true),
+ ArrayType(BooleanType, containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a0_0, Literal(true)), 1)
+ val a0_1 = Literal.create(Seq(null, false, true), ArrayType(BooleanType))
+ checkEvaluation(ArrayBinarySearch(a0_1, Literal(false)), 1)
+ val a0_2 = Literal.create(Seq(null, false, true), ArrayType(BooleanType))
+ checkEvaluation(ArrayBinarySearch(a0_2, Literal(null, BooleanType)), null)
+
+ val a1_0 = Literal.create(Seq(1.toByte, 2.toByte, 3.toByte),
+ ArrayType(ByteType, containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a1_0, Literal(3.toByte)), 2)
+ val a1_1 = Literal.create(Seq(null, 1.toByte, 2.toByte, 3.toByte),
ArrayType(ByteType))
+ checkEvaluation(ArrayBinarySearch(a1_1, Literal(1.toByte)), 1)
+ val a1_2 = Literal.create(Seq(null, 1.toByte, 2.toByte, 3.toByte),
ArrayType(ByteType))
+ checkEvaluation(ArrayBinarySearch(a1_2, Literal(null, ByteType)), null)
+ val a1_3 = Literal.create(Seq(1.toByte, 3.toByte, 4.toByte),
+ ArrayType(ByteType, containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a1_3, Literal(2.toByte, ByteType)), -2)
+
+ val a2_0 = Literal.create(Seq(1.toShort, 2.toShort, 3.toShort),
+ ArrayType(ShortType, containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a2_0, Literal(1.toShort)), 0)
+ val a2_1 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort),
ArrayType(ShortType))
+ checkEvaluation(ArrayBinarySearch(a2_1, Literal(2.toShort)), 2)
+ val a2_2 = Literal.create(Seq(null, 1.toShort, 2.toShort, 3.toShort),
ArrayType(ShortType))
+ checkEvaluation(ArrayBinarySearch(a2_2, Literal(null, ShortType)), null)
+ val a2_3 = Literal.create(Seq(1.toShort, 3.toShort, 4.toShort),
+ ArrayType(ShortType, containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a2_3, Literal(2.toShort, ShortType)), -2)
+
+ val a3_0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a3_0, Literal(2)), 1)
+ val a3_1 = Literal.create(Seq(null, 1, 2, 3), ArrayType(IntegerType))
+ checkEvaluation(ArrayBinarySearch(a3_1, Literal(2)), 2)
+ val a3_2 = Literal.create(Seq(null, 1, 2, 3), ArrayType(IntegerType))
+ checkEvaluation(ArrayBinarySearch(a3_2, Literal(null, IntegerType)), null)
+ val a3_3 = Literal.create(Seq(1, 3, 4), ArrayType(IntegerType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a3_3, Literal(2, IntegerType)), -2)
+
+ val a4_0 = Literal.create(Seq(1L, 2L, 3L), ArrayType(LongType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a4_0, Literal(2L)), 1)
+ val a4_1 = Literal.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType))
+ checkEvaluation(ArrayBinarySearch(a4_1, Literal(2L)), 2)
+ val a4_2 = Literal.create(Seq(null, 1L, 2L, 3L), ArrayType(LongType))
+ checkEvaluation(ArrayBinarySearch(a4_2, Literal(null, LongType)), null)
+ val a4_3 = Literal.create(Seq(1L, 3L, 4L), ArrayType(LongType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a4_3, Literal(2L, LongType)), -2)
+
+ val a5_0 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a5_0, Literal(3.0F)), 2)
+ val a5_1 = Literal.create(Seq(null, 1.0F, 2.0F, 3.0F),
ArrayType(FloatType))
+ checkEvaluation(ArrayBinarySearch(a5_1, Literal(1.0F)), 1)
+ val a5_2 = Literal.create(Seq(null, 1.0F, 2.0F, 3.0F),
ArrayType(FloatType))
+ checkEvaluation(ArrayBinarySearch(a5_2, Literal(null, FloatType)), null)
+ val a5_3 = Literal.create(Seq(1.0F, 2.0F, 3.0F), ArrayType(FloatType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a5_3, Literal(1.1F, FloatType)), -2)
+
+ val a6_0 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a6_0, Literal(1.0d)), 0)
+ val a6_1 = Literal.create(Seq(null, 1.0d, 2.0d, 3.0d),
ArrayType(DoubleType))
+ checkEvaluation(ArrayBinarySearch(a6_1, Literal(1.0d)), 1)
+ val a6_2 = Literal.create(Seq(null, 1.0d, 2.0d, 3.0d),
ArrayType(DoubleType))
+ checkEvaluation(ArrayBinarySearch(a6_2, Literal(null, DoubleType)), null)
+ val a6_3 = Literal.create(Seq(1.0d, 2.0d, 3.0d), ArrayType(DoubleType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a6_3, Literal(1.1d, DoubleType)), -2)
+
+ // string
+ val a7_0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a7_0, Literal("a")), 0)
+ val a7_1 = Literal.create(Seq(null, "a", "b", "c"), ArrayType(StringType))
+ checkEvaluation(ArrayBinarySearch(a7_1, Literal("c")), 3)
+ val a7_2 = Literal.create(Seq(null, "a", "b", "c"), ArrayType(StringType))
+ checkEvaluation(ArrayBinarySearch(a7_2, Literal(null, StringType)), null)
+ val a7_3 = Literal.create(Seq("a", "c", "d"), ArrayType(StringType,
containsNull = false))
+ checkEvaluation(ArrayBinarySearch(a7_3,
Literal(UTF8String.fromString("b"), StringType)), -2)
+ }
+
test("MapEntries") {
def r(values: Any*): InternalRow = create_row(values: _*)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]