jaceklaskowski commented on code in PR #34558:
URL: https://github.com/apache/spark/pull/34558#discussion_r1238227638
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala:
##########
@@ -172,6 +172,39 @@ class CodegenContext extends Logging {
*/
var currentVars: Seq[ExprCode] = null
+ /**
+ * Holding a map of current lambda variables.
+ */
+ var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty
+
+ def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable],
Review Comment:
Can you put `namedLambdas...` on a separate line?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala:
##########
@@ -172,6 +172,39 @@ class CodegenContext extends Logging {
*/
var currentVars: Seq[ExprCode] = null
+ /**
+ * Holding a map of current lambda variables.
+ */
+ var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty
+
+ def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable],
+ f: Seq[ExprCode] => ExprCode): ExprCode = {
+ val lambdaVars = namedLambdas.map { namedLambda =>
Review Comment:
nit: Replace `namedLambda` with `lambda`?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala:
##########
@@ -280,6 +335,29 @@ trait SimpleHigherOrderFunction extends
HigherOrderFunction with BinaryLike[Expr
}
}
+ protected def nullSafeCodeGen(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ f: String => String): ExprCode = {
Review Comment:
I'd be very happy if you used this parameter formatting style in the other
places in your PR. It makes reading so much easier, esp. for functions with 5+
params. Can you make that change? 🙏
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala:
##########
@@ -781,6 +1007,49 @@ case class ArrayForAll(
}
}
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
+ ctx.withLambdaVars(Seq(elementVar), { case Seq(elementCode) =>
+ nullSafeCodeGen(ctx, ev, arg => {
+ val numElements = ctx.freshName("numElements")
+ val forall = ctx.freshName("forall")
+ val foundNull = ctx.freshName("foundNull")
+ val i = ctx.freshName("i")
+
+ val functionCode = function.genCode(ctx)
+ val elementAssignment = assignArrayElement(ctx, arg, elementCode,
elementVar, i)
+
+ val nullCheck = if (nullable) {
+ s"""
+ if ($forall && $foundNull) {
+ ${ev.isNull} = true;
+ }
+ """
+ } else {
+ ""
+ }
+
+ s"""
+ |final int $numElements = ${arg}.numElements();
+ |boolean $forall = true;
+ |boolean $foundNull = false;
+ |int $i = 0;
+ |while ($i < $numElements && $forall) {
+ | $elementAssignment
+ | ${functionCode.code}
+ | if (${functionCode.isNull}) {
+ | $foundNull = true;
+ | } else if (!${functionCode.value}) {
+ | $forall = false;
+ | }
+ | $i++;
+ |}
+ |$nullCheck
+ |${ev.value} = $forall;
+ """.stripMargin
+ })
+ })
+ }
+
Review Comment:
Looks like a copy and paste of `exists`, doesn't it? Can we have a parent
class for some sharing? Unless I'm mistaken, the generated code block is an
exact copy except for `$forall` (vs `$exists`).
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala:
##########
@@ -172,6 +172,39 @@ class CodegenContext extends Logging {
*/
var currentVars: Seq[ExprCode] = null
+ /**
+ * Holding a map of current lambda variables.
+ */
+ var currentLambdaVars: mutable.Map[Long, ExprCode] = mutable.HashMap.empty
+
+ def withLambdaVars(namedLambdas: Seq[NamedLambdaVariable],
+ f: Seq[ExprCode] => ExprCode): ExprCode = {
+ val lambdaVars = namedLambdas.map { namedLambda =>
+ val id = namedLambda.exprId.id
+ if (currentLambdaVars.get(id).nonEmpty) {
+ throw QueryExecutionErrors.lambdaVariableAlreadyDefinedError(id)
+ }
+ val isNull = if (namedLambda.nullable) {
+ JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "lambdaIsNull"))
+ } else {
+ FalseLiteral
+ }
+ val value = addMutableState(javaType(namedLambda.dataType),
"lambdaValue")
+ val lambdaVar = ExprCode(isNull, JavaCode.global(value,
namedLambda.dataType))
+ currentLambdaVars.put(id, lambdaVar)
+ lambdaVar
+ }
+
+ val result = f(lambdaVars)
+ namedLambdas.map(_.exprId.id).foreach(currentLambdaVars.remove)
+ result
+ }
+
+ def getLambdaVar(id: Long): ExprCode = {
+ currentLambdaVars.getOrElse(id,
Review Comment:
Can you move `id` to its own new line? 🙏
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]