nvander1 commented on a change in pull request #24761: [SPARK-27905] [SQL] Add higher order function 'forall'
URL: https://github.com/apache/spark/pull/24761#discussion_r291686074
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
##########
@@ -393,35 +418,38 @@ case class ArrayFilter(
case class ArrayExists(
argument: Expression,
function: Expression)
- extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
-
- override def dataType: DataType = BooleanType
-
- override def functionType: AbstractDataType = BooleanType
-
+ extends ArrayExistsForAllBase {
+ override def prettyName: String = "exists"
+ override def check(cond: Boolean): Boolean = cond
override def bind(f: (Expression, Seq[(DataType, Boolean)]) =>
LambdaFunction): ArrayExists = {
val ArrayType(elementType, containsNull) = argument.dataType
copy(function = f(function, (elementType, containsNull) :: Nil))
}
+}
- @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable),
_) = function
-
- override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
- val arr = argumentValue.asInstanceOf[ArrayData]
- val f = functionForEval
- var exists = false
- var i = 0
- while (i < arr.numElements && !exists) {
- elementVar.value.set(arr.get(i, elementVar.dataType))
- if (f.eval(inputRow).asInstanceOf[Boolean]) {
- exists = true
- }
- i += 1
- }
- exists
+/**
+ * Tests whether a predicate holds for all elements in the array.
+ */
+@ExpressionDescription(usage =
+ "_FUNC_(expr, pred) - Tests whether a predicate holds for all elements in
the array.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 3, 7), x -> x % 2 == 1);
+ true
+ > SELECT _FUNC_(array(1, 3, 6), x -> x % 2 == 1);
+ false
+ """,
+ since = "3.0.0")
+case class ArrayForAll(
+ argument: Expression,
+ function: Expression)
+ extends ArrayExistsForAllBase {
+ override def prettyName: String = "forall"
+ override def check(cond: Boolean): Boolean = !cond
+ override def bind(f: (Expression, Seq[(DataType, Boolean)]) =>
LambdaFunction): ArrayForAll = {
Review comment:
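I tried to factor out the `bind` definition into the shared parent trait as well, but there is a known issue with relying on a case class's generated `copy` method from its parent trait: https://www.scala-lang.org/old/node/6369
The compiler only synthesizes `copy` for a case class that does not already have (or inherit) a member named `copy`. So the trait cannot declare an abstract `copy` for `bind` to call without suppressing the generated one.
As an aside, `ArrayFilter` has this same `bind` method too, so it would benefit from the same refactoring if a workaround is found.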
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]