GitHub user mgaido91 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19752#discussion_r152537080
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala ---
    @@ -211,111 +231,61 @@ abstract class CaseWhenBase(
         val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
         "CASE" + cases + elseCase + " END"
       }
    -}
    -
    -
    -/**
    - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
    - * When a = true, returns b; when c = true, returns d; else returns e.
    - *
    - * @param branches seq of (branch condition, branch value)
    - * @param elseValue optional value for the else branch
    - */
    -// scalastyle:off line.size.limit
    -@ExpressionDescription(
    -  usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
    -  arguments = """
    -    Arguments:
    -      * expr1, expr3 - the branch condition expressions should all be boolean type.
    -      * expr2, expr4, expr5 - the branch value expressions and else value expression should all be
    -          same type or coercible to a common type.
    -  """,
    -  examples = """
    -    Examples:
    -      > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
    -       1
    -      > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
    -       2
    -      > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
    -       NULL
    -  """)
    -// scalastyle:on line.size.limit
    -case class CaseWhen(
    -    val branches: Seq[(Expression, Expression)],
    -    val elseValue: Option[Expression] = None)
    -  extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
    -
    -  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    -    super[CodegenFallback].doGenCode(ctx, ev)
    -  }
    -
    -  def toCodegen(): CaseWhenCodegen = {
    -    CaseWhenCodegen(branches, elseValue)
    -  }
    -}
    -
    -/**
    - * CaseWhen expression used when code generation condition is satisfied.
    - * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
    - *
    - * @param branches seq of (branch condition, branch value)
    - * @param elseValue optional value for the else branch
    - */
    -case class CaseWhenCodegen(
    -    val branches: Seq[(Expression, Expression)],
    -    val elseValue: Option[Expression] = None)
    -  extends CaseWhenBase(branches, elseValue) with Serializable {
     
       override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    -    // Generate code that looks like:
    -    //
    -    // condA = ...
    -    // if (condA) {
    -    //   valueA
    -    // } else {
    -    //   condB = ...
    -    //   if (condB) {
    -    //     valueB
    -    //   } else {
    -    //     condC = ...
    -    //     if (condC) {
    -    //       valueC
    -    //     } else {
    -    //       elseValue
    -    //     }
    -    //   }
    -    // }
    +    val conditionMet = ctx.freshName("caseWhenConditionMet")
    +    ctx.addMutableState("boolean", ev.isNull, "")
    +    ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
         val cases = branches.map { case (condExpr, valueExpr) =>
           val cond = condExpr.genCode(ctx)
           val res = valueExpr.genCode(ctx)
           s"""
    -        ${cond.code}
    -        if (!${cond.isNull} && ${cond.value}) {
    -          ${res.code}
    -          ${ev.isNull} = ${res.isNull};
    -          ${ev.value} = ${res.value};
    +        if(!$conditionMet) {
    +          ${cond.code}
    +          if (!${cond.isNull} && ${cond.value}) {
    +            ${res.code}
    +            ${ev.isNull} = ${res.isNull};
    +            ${ev.value} = ${res.value};
    +            $conditionMet = true;
    +          }
             }
           """
         }
     
    -    var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
    -
    -    elseValue.foreach { elseExpr =>
    +    val elseCode = elseValue.map { elseExpr =>
           val res = elseExpr.genCode(ctx)
    -      generatedCode +=
    -        s"""
    +      s"""
    +        if(!$conditionMet) {
               ${res.code}
               ${ev.isNull} = ${res.isNull};
               ${ev.value} = ${res.value};
    -        """
    -    }
    +        }
    +      """
    +    }.getOrElse("")
     
    -    generatedCode += "}\n" * cases.size
    +    val casesCode = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
    +      cases.mkString("\n")
    +    } else {
    +      ctx.splitExpressions(cases, "caseWhen",
    --- End diff --
    
    I think we do need to call it, as explained in this comment: https://github.com/apache/spark/pull/19767#issuecomment-345176286


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to