LuciferYang commented on code in PR #47741:
URL: https://github.com/apache/spark/pull/47741#discussion_r1719322782
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -1518,6 +1519,137 @@ case class ArrayContains(left: Expression, right: Expression)
     copy(left = newLeft, right = newRight)
   }
+/**
+ * Searches the specified array for the specified value using the binary search algorithm.
+ * This expression is intended only for PySpark and Spark-ML.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array, value) - Returns the index (0-based) of the search value " +
+    "if it is contained in the array; otherwise, (-<insertion point> - 1).",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), 2);
+       1
+      > SELECT _FUNC_(array(null, 1, 2, 3), 2);
+       2
+      > SELECT _FUNC_(array(1.0F, 2.0F, 3.0F), 1.1F);
+       -2
+  """,
+  group = "array_funcs",
+  since = "4.0.0")
+case class ArrayBinarySearch(array: Expression, value: Expression)
+  extends BinaryExpression
+  with ImplicitCastInputTypes
+  with NullIntolerant
+  with RuntimeReplaceable
+  with QueryErrorsBase {
+
+  override def left: Expression = array
+  override def right: Expression = value
+  override def dataType: DataType = IntegerType
+
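+  // Coerce both inputs to the tightest common type of the array's element type
+  // and the search value; when no such type exists, request no implicit cast.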
+  override def inputTypes: Seq[AbstractDataType] = {
+    (left.dataType, right.dataType) match {
+      case (_, NullType) => Seq.empty
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+
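+  // Flag NULL-typed inputs, a non-array first argument, unorderable element
+  // types, and element/value type mismatches as type-check errors.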
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (NullType, _) | (_, NullType) =>
+        DataTypeMismatch(
+          errorSubClass = "NULL_TYPE",
+          Map("functionName" -> toSQLId(prettyName)))
+      case (t, _) if !ArrayType.acceptsType(t) =>
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> ordinalNumber(0),
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(left),
+            "inputType" -> toSQLType(left.dataType))
+        )
+      case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
+        TypeUtils.checkForOrderingExpr(e2, prettyName)
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "dataType" -> toSQLType(ArrayType),
+            "leftType" -> toSQLType(left.dataType),
+            "rightType" -> toSQLType(right.dataType)
+          )
+        )
+    }
+  }
+
+  @transient private lazy val elementType: DataType =
+    array.dataType.asInstanceOf[ArrayType].elementType
+  @transient private lazy val resultArrayElementNullable: Boolean =
+    array.dataType.asInstanceOf[ArrayType].containsNull
+
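+  // The fast binary-search path is only taken for non-boolean primitive
+  // element types in arrays that cannot contain null entries.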
+  @transient private lazy val isPrimitiveType: Boolean =
+    CodeGenerator.isPrimitiveType(elementType)
+  @transient private lazy val canPerformFastBinarySearch: Boolean = isPrimitiveType &&
+    elementType != BooleanType && !resultArrayElementNullable
+
+  @transient private lazy val comp: SerializableComparator[Any] = {
Review Comment:
   If `comp` needs to be `Serializable`, I think the following definition can be adopted to avoid defining an additional interface:
```scala
@transient private lazy val comp: Comparator[Any] = new Comparator[Any] with Serializable {
  private val ordering: Ordering[Any] = array.dataType match {
    case ArrayType(n, _) => PhysicalDataType.ordering(n)
  }
  override def compare(o1: Any, o2: Any): Int = {
    (o1, o2) match {
      case (null, null) => 0
      case (null, _) => 1
      case (_, null) => -1
      case _ => ordering.compare(o1, o2)
    }
  }
}
```
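
As a sanity check, here is a minimal standalone sketch (plain Scala, outside Spark; the object name `ComparatorRoundTrip` is just for illustration) showing that an anonymous `Comparator` mixed with `Serializable` survives a Java serialization round trip, so no dedicated `SerializableComparator` interface is needed:
```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.Comparator

object ComparatorRoundTrip {
  def main(args: Array[String]): Unit = {
    // Anonymous class implementing both java.util.Comparator and Serializable.
    val comp: Comparator[Integer] = new Comparator[Integer] with Serializable {
      override def compare(a: Integer, b: Integer): Int = Integer.compare(a, b)
    }

    // Without the Serializable mixin, writeObject would throw NotSerializableException.
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(comp)
    out.close()

    // The deserialized comparator is immediately usable again.
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val restored = in.readObject().asInstanceOf[Comparator[Integer]]
    println(restored.compare(Integer.valueOf(1), Integer.valueOf(2))) // -1
  }
}
```
Note that `scala.math.Ordering` itself already extends `Comparator` with `Serializable`, which is why the `ordering` value captured in the suggestion above serializes along with the enclosing anonymous class.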