GitHub user mgaido91 commented on a diff in the pull request:
https://github.com/apache/spark/pull/19752#discussion_r153118387
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
---
@@ -211,111 +231,73 @@ abstract class CaseWhenBase(
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
-}
-
-
-/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE
e] END".
- * When a = true, returns b; when c = true, returns d; else returns e.
- *
- * @param branches seq of (branch condition, branch value)
- * @param elseValue optional value for the else branch
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE
expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true,
returns `expr4`; else returns `expr5`.",
- arguments = """
- Arguments:
- * expr1, expr3 - the branch condition expressions should all be
boolean type.
- * expr2, expr4, expr5 - the branch value expressions and else value
expression should all be
- same type or coercible to a common type.
- """,
- examples = """
- Examples:
- > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
- 1
- > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
- 2
- > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
- NULL
- """)
-// scalastyle:on line.size.limit
-case class CaseWhen(
- val branches: Seq[(Expression, Expression)],
- val elseValue: Option[Expression] = None)
- extends CaseWhenBase(branches, elseValue) with CodegenFallback with
Serializable {
-
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- super[CodegenFallback].doGenCode(ctx, ev)
- }
-
- def toCodegen(): CaseWhenCodegen = {
- CaseWhenCodegen(branches, elseValue)
- }
-}
-
-/**
- * CaseWhen expression used when code generation condition is satisfied.
- * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
- *
- * @param branches seq of (branch condition, branch value)
- * @param elseValue optional value for the else branch
- */
-case class CaseWhenCodegen(
- val branches: Seq[(Expression, Expression)],
- val elseValue: Option[Expression] = None)
- extends CaseWhenBase(branches, elseValue) with Serializable {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- // Generate code that looks like:
- //
- // condA = ...
- // if (condA) {
- // valueA
- // } else {
- // condB = ...
- // if (condB) {
- // valueB
- // } else {
- // condC = ...
- // if (condC) {
- // valueC
- // } else {
- // elseValue
- // }
- // }
- // }
+ // This variable represents whether the first successful condition is
met or not.
+ // It is initialized to `false` and it is set to `true` when the first
condition which
+ // evaluates to `true` is met and therefore is not needed to go on
anymore on the computation
+ // of the following conditions.
+ val conditionMet = ctx.freshName("caseWhenConditionMet")
+ ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull, "")
+ ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
val cases = branches.map { case (condExpr, valueExpr) =>
val cond = condExpr.genCode(ctx)
val res = valueExpr.genCode(ctx)
s"""
- ${cond.code}
- if (!${cond.isNull} && ${cond.value}) {
- ${res.code}
- ${ev.isNull} = ${res.isNull};
- ${ev.value} = ${res.value};
+ if(!$conditionMet) {
+ ${cond.code}
+ if (!${cond.isNull} && ${cond.value}) {
+ ${res.code}
+ ${ev.isNull} = ${res.isNull};
+ ${ev.value} = ${res.value};
+ $conditionMet = true;
+ }
}
"""
}
- var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
-
- elseValue.foreach { elseExpr =>
+ val elseCode = elseValue.map { elseExpr =>
val res = elseExpr.genCode(ctx)
- generatedCode +=
- s"""
+ s"""
+ if(!$conditionMet) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
- """
+ }
+ """
}
- generatedCode += "}\n" * cases.size
+ val allConditions = cases ++ elseCode
+
+ val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+ allConditions.mkString("\n")
+ } else {
+ ctx.splitExpressions(allConditions, "caseWhen",
+ ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_BOOLEAN,
conditionMet) :: Nil,
+ returnType = ctx.JAVA_BOOLEAN,
+ makeSplitFunction = {
+ func =>
+ s"""
+ $func
+ return $conditionMet;
+ """
+ },
+ foldFunctions = { funcCalls =>
+ funcCalls.map { funcCall =>
+ s"""
+ $conditionMet = $funcCall;
+ if ($conditionMet) {
+ continue;
+ }"""
+ }.mkString("do {", "", "\n} while (false);")
--- End diff --
No, because each generated expression already begins with a newline.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]