Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19901#discussion_r155168991
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
 ---
    @@ -212,59 +213,66 @@ case class CaseWhen(
           val res = elseExpr.genCode(ctx)
           s"""
              |${res.code}
    -         |${ev.isNull} = ${res.isNull};
    -         |${ev.value} = ${res.value};
    +         |$resultIsNull = (byte)(${res.isNull} ? 1 : 0);
    +         |$tmpResult = ${res.value};
            """.stripMargin
         }
     
         val allConditions = cases ++ elseCode
     
         // This generates code like:
    -    //   conditionMet = caseWhen_1(i);
    -    //   if(conditionMet) {
    +    //   caseWhenResultIsNull = caseWhen_1(i);
    +    //   if(caseWhenResultIsNull != -1) {
         //     continue;
         //   }
    -    //   conditionMet = caseWhen_2(i);
    -    //   if(conditionMet) {
    +    //   caseWhenResultIsNull = caseWhen_2(i);
    +    //   if(caseWhenResultIsNull != -1) {
         //     continue;
         //   }
         //   ...
         // and the declared methods are:
    -    //   private boolean caseWhen_1234() {
    -    //     boolean conditionMet = false;
    +    //   private byte caseWhen_1234() {
    +    //     byte caseWhenResultIsNull = -1;
         //     do {
         //       // here the evaluation of the conditions
         //     } while (false);
    -    //     return conditionMet;
    +    //     return caseWhenResultIsNull;
         //   }
         val codes = ctx.splitExpressionsWithCurrentInputs(
           expressions = allConditions,
           funcName = "caseWhen",
    -      returnType = ctx.JAVA_BOOLEAN,
    -      makeSplitFunction = func =>
    -        s"""
    -           |${ctx.JAVA_BOOLEAN} $conditionMet = false;
    -           |do {
    -           |  $func
    -           |} while (false);
    -           |return $conditionMet;
    -         """.stripMargin,
    -      foldFunctions = _.map { funcCall =>
    -        s"""
    -           |$conditionMet = $funcCall;
    -           |if ($conditionMet) {
    -           |  continue;
    -           |}
    -         """.stripMargin
    -      }.mkString)
    +      returnType = ctx.JAVA_BYTE,
    +      makeSplitFunction = {
    +        func =>
    +          s"""
    +             |${ctx.JAVA_BYTE} $resultIsNull = -1;
    +             |do {
    +             |  $func
    +             |} while (false);
    +             |return $resultIsNull;
    +           """.stripMargin
    +      },
    +      foldFunctions = { funcCalls =>
    +        funcCalls.map { funcCall =>
    +          s"""
    +             |$resultIsNull = $funcCall;
    +             |if ($resultIsNull != -1) {
    +             |  continue;
    +             |}
    +           """.stripMargin
    +        }.mkString
    +      })
     
    -    ev.copy(code = s"""
    -      ${ev.isNull} = true;
    -      ${ev.value} = ${ctx.defaultValue(dataType)};
    -      ${ctx.JAVA_BOOLEAN} $conditionMet = false;
    -      do {
    -        $codes
    -      } while (false);""")
    +    ev.copy(code =
    +      s"""
    +         |${ctx.JAVA_BYTE} $resultIsNull = -1;
    +         |$tmpResult = ${ctx.defaultValue(dataType)};
    +         |do {
    +         |  $codes
    +         |} while (false);
    +         |boolean ${ev.isNull} = ($resultIsNull != 0); // TRUE if -1 or 1
    +         |${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
    --- End diff --
    
    nit: declare the generated local variable here as `final`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to