LuciferYang commented on code in PR #47741:
URL: https://github.com/apache/spark/pull/47741#discussion_r1718351015
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -1518,6 +1519,118 @@ case class ArrayContains(left: Expression, right: Expression)
     copy(left = newLeft, right = newRight)
 }
+/**
+ * Searches the specified array for the specified object using the binary search algorithm.
+ * This expression is dedicated only for PySpark and Spark-ML.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array, value) - Return index (0-based) of the search value, " +
+    "if it is contained in the array; otherwise, (-<insertion point> - 1).",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 2);
+       1
+      > SELECT _FUNC_(array(null, 1, 2, 3), 2);
+       2
+      > SELECT _FUNC_(array(1.0F, 2.0F, 3.0F), 1.1F);
+       -2
+  """,
+  group = "array_funcs",
+  since = "4.0.0")
+case class ArrayBinarySearch(array: Expression, value: Expression)
+  extends BinaryExpression
+  with ImplicitCastInputTypes
+  with NullIntolerant
+  with RuntimeReplaceable
+  with QueryErrorsBase {
+
+  override def left: Expression = array
+  override def right: Expression = value
+  override def dataType: DataType = IntegerType
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (_, NullType) => Seq.empty
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) =>
+        DataTypeMismatch(
+          errorSubClass = "NULL_TYPE",
+          Map("functionName" -> toSQLId(prettyName)))
+      case (t, _) if !ArrayType.acceptsType(t) =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(0),
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType))
+        )
+      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+        TypeUtils.checkForOrderingExpr(e2, prettyName)
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "dataType" -> toSQLType(ArrayType),
+            "leftType" -> toSQLType(left.dataType),
+            "rightType" -> toSQLType(right.dataType)
+          )
+        )
+    }
+  }
+
+  @transient private lazy val elementType: DataType =
+    array.dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val resultArrayElementNullable: Boolean =
+    array.dataType.asInstanceOf[ArrayType].containsNull
+
+  @transient private lazy val isPrimitiveType: Boolean = CodeGenerator.isPrimitiveType(elementType)
+  @transient private lazy val canPerformFastBinarySearch: Boolean = isPrimitiveType &&
+    elementType != BooleanType && !resultArrayElementNullable
+
+  override def replacement: Expression =
+    if (canPerformFastBinarySearch) {
+      StaticInvoke(
+        classOf[ArrayExpressionUtils],
+        IntegerType,
+        "binarySearch",
+        Seq(array, value),
+        inputTypes)
+    } else if (isPrimitiveType) {
+      StaticInvoke(
+        classOf[ArrayExpressionUtils],
+        IntegerType,
+        "binarySearchNullSafe",
Review Comment:
hmm... I made an attempt:
1. Replaced `public static int binarySearchNullSafe(ArrayData data, Boolean value)` with
```java
public static int binarySearch(ArrayData data, Boolean value) {
  return Arrays.binarySearch(data.toJavaBooleanArray(), value);
}
```
2. Removed all the other `binarySearchNullSafe` methods in `ArrayExpressionUtils`.
3. Changed this line to call `binarySearch` as well.
4. Ran `build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.expressions.CollectionExpressionsSuite"`, which produced:
```
[info] CollectionExpressionsSuite:
[info] - Array and Map Size - legacy (734 milliseconds)
[info] - Array and Map Size (108 milliseconds)
[info] - Unsupported data type for size() (145 milliseconds)
[info] - MapKeys/MapValues (64 milliseconds)
[info] - MapContainsKey (2 milliseconds)
[info] - ArrayContains (36 milliseconds)
[info] - ArrayBinarySearch (249 milliseconds)
[info] - MapEntries (93 milliseconds)
[info] - Map Concat (204 milliseconds)
[info] - MapFromEntries (67 milliseconds)
[info] - Sort Map (98 milliseconds)
[info] - Sort Array (140 milliseconds)
[info] - Array contains (83 milliseconds)
[info] - ArraysOverlap (89 milliseconds)
[info] - Slice (107 milliseconds)
[info] - ArrayJoin (53 milliseconds)
[info] - ArraysZip (756 milliseconds)
[info] - Array Min (36 milliseconds)
[info] - Array max (22 milliseconds)
[info] - Sequence of numbers (526 milliseconds)
[info] - Sequence of timestamps (173 milliseconds)
[info] - Sequence on DST boundaries (28 milliseconds)
[info] - Sequence of dates (96 milliseconds)
[info] - SPARK-37544: Time zone should not affect date sequence with month interval (136 milliseconds)
[info] - SPARK-35088: Accept ANSI intervals by the Sequence expression (209 milliseconds)
[info] - SPARK-36090: Support TimestampNTZType in expression Sequence (111 milliseconds)
[info] - Sequence with default step (64 milliseconds)
[info] - Reverse (48 milliseconds)
[info] - Array Position (47 milliseconds)
[info] - elementAt (146 milliseconds)
[info] - correctly handles ElementAt nullability for arrays (2 milliseconds)
[info] - Concat (125 milliseconds)
[info] - Flatten (71 milliseconds)
[info] - ArrayRepeat (55 milliseconds)
[info] - Array remove (87 milliseconds)
[info] - Array Distinct (78 milliseconds)
[info] - Array Union (110 milliseconds)
[info] - Shuffle (73 milliseconds)
[info] - Array Except (144 milliseconds)
[info] - Array Except - null handling (33 milliseconds)
[info] - Array Insert (164 milliseconds)
[info] - Array Intersect (143 milliseconds)
[info] - Array Intersect - null handling (32 milliseconds)
[info] - SPARK-31980: Start and end equal in month range (14 milliseconds)
[info] - SPARK-36639: Start and end equal in month range with a negative step (6 milliseconds)
[info] - SPARK-33386: element_at ArrayIndexOutOfBoundsException (17 milliseconds)
[info] - SPARK-40066: element_at returns null on invalid map value access (8 milliseconds)
[info] - SPARK-36702: ArrayUnion should handle duplicated Double.NaN and Float.Nan (26 milliseconds)
[info] - SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan (18 milliseconds)
[info] - SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan (13 milliseconds)
[info] - SPARK-36741: ArrayDistinct should handle duplicated Double.NaN and Float.Nan (12 milliseconds)
[info] - SPARK-36755: ArraysOverlap hould handle duplicated Double.NaN and Float.Nan (11 milliseconds)
[info] - SPARK-36740: ArrayMin/ArrayMax/SortArray should handle NaN greater than non-NaN value (25 milliseconds)
[info] - SPARK-39184: Avoid ArrayIndexOutOfBoundsException when crossing DST boundary (38 milliseconds)
[info] - SPARK-42401: Array insert of null value (explicit) (12 milliseconds)
[info] - SPARK-42401: Array insert of null value (implicit) (6 milliseconds)
[info] Run completed in 6 seconds, 830 milliseconds.
[info] Total number of tests run: 56
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 56, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```
So is this due to insufficient test coverage, or were those `binarySearchNullSafe` methods unnecessary to define in the first place?
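
For reference, a minimal standalone sketch (plain JDK, not Spark code) of the `java.util.Arrays.binarySearch` semantics that the `usage` string above documents, plus the null case that the `binarySearchNullSafe` variants presumably exist to cover. The array contents are made up for illustration:

```java
import java.util.Arrays;

public class BinarySearchSemanticsDemo {
  public static void main(String[] args) {
    // Found: the 0-based index is returned.
    System.out.println(Arrays.binarySearch(new int[]{1, 2, 3}, 2));                // 1

    // Not found: (-(insertion point) - 1) is returned; 1.1f would be inserted at
    // index 1, so the result is -2, matching the third example in the usage doc.
    System.out.println(Arrays.binarySearch(new float[]{1.0f, 2.0f, 3.0f}, 1.1f));  // -2

    // The boxed overload compares via Comparable#compareTo, so a probe that lands
    // on a null slot (or a null search key) throws NullPointerException.
    Integer[] withNull = {1, 2, null, 4};
    try {
      Arrays.binarySearch(withNull, 3);  // the second probe hits index 2, which is null
    } catch (NullPointerException e) {
      System.out.println("NPE when a probe lands on a null element");
    }
  }
}
```

Whether the existing suite ever exercises a case where the search actually has to compare against a null element (rather than skipping past it) might be worth checking before dropping the null-safe variants.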
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]