Repository: spark Updated Branches: refs/heads/master ced6ccf0d -> 132a3f470
[SPARK-22500][SQL][FOLLOWUP] cast for struct can split code even with whole stage codegen ## What changes were proposed in this pull request? This is a follow-up to https://github.com/apache/spark/pull/19730: the code for casting a struct can be split even with whole-stage codegen. This PR also renames some variables to make the code easier to read. ## How was this patch tested? Existing tests. Author: Wenchen Fan <wenc...@databricks.com> Closes #19891 from cloud-fan/cast. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/132a3f47 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/132a3f47 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/132a3f47 Branch: refs/heads/master Commit: 132a3f470811bb98f265d0c9ad2c161698e0237b Parents: ced6ccf Author: Wenchen Fan <wenc...@databricks.com> Authored: Tue Dec 5 11:40:13 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Tue Dec 5 11:40:13 2017 -0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/expressions/Cast.scala | 52 +++++++++----------- 1 file changed, 24 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/132a3f47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index f4ecbdb..b8d3661 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, 
dataType, nullSafeCast)) } - // three function arguments are: child.primitive, result.primitive and result.isNull - // it returns the code snippets to be put in null safe evaluation region + // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` + // in parameter list, because the returned code will be put in null safe evaluation region. private[this] type CastFunction = (String, String, String) => String private[this] def nullSafeCastFunction( @@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String throw new SparkException(s"Cannot cast $from to $to.") } - // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String, - resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String, + result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = { s""" - boolean $resultNull = $childNull; - ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; - if (!$childNull) { - ${cast(childPrim, resultPrim, resultNull)} + boolean $resultIsNull = $inputIsNull; + ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)}; + if (!$inputIsNull) { + ${cast(input, result, resultIsNull)} } """ } @@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } val rowClass = classOf[GenericInternalRow].getName - val result = ctx.freshName("result") - val tmpRow = ctx.freshName("tmpRow") + val 
tmpResult = ctx.freshName("tmpResult") + val tmpInput = ctx.freshName("tmpInput") val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") @@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val toFieldNull = ctx.freshName("tfn") val fromType = ctx.javaType(from.fields(i).dataType) s""" - boolean $fromFieldNull = $tmpRow.isNullAt($i); + boolean $fromFieldNull = $tmpInput.isNullAt($i); if ($fromFieldNull) { - $result.setNullAt($i); + $tmpResult.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; + ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { - $result.setNullAt($i); + $tmpResult.setNullAt($i); } else { - ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)}; } } """ } - val fieldsEvalCodes = if (ctx.currentVars == null) { - ctx.splitExpressions( - expressions = fieldsEvalCode, - funcName = "castStruct", - arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil) - } else { - fieldsEvalCode.mkString("\n") - } + val fieldsEvalCodes = ctx.splitExpressions( + expressions = fieldsEvalCode, + funcName = "castStruct", + arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil) - (c, evPrim, evNull) => + (input, result, resultIsNull) => s""" - final $rowClass $result = new $rowClass(${fieldsCasts.length}); - final InternalRow $tmpRow = $c; + final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length}); + final InternalRow $tmpInput = $input; $fieldsEvalCodes - $evPrim = $result; + $result = $tmpResult; """ } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, 
e-mail: commits-h...@spark.apache.org