viirya commented on a change in pull request #32457: URL: https://github.com/apache/spark/pull/32457#discussion_r629613414
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala ########## @@ -152,38 +152,66 @@ case class ExpandExec( // This column is the same across all output rows. Just generate code for it here. BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx) } else { - val isNull = ctx.freshName("isNull") - val value = ctx.freshName("value") - val code = code""" - |boolean $isNull = true; - |${CodeGenerator.javaType(firstExpr.dataType)} $value = - | ${CodeGenerator.defaultValue(firstExpr.dataType)}; - """.stripMargin + val isNull = ctx.addMutableState( + CodeGenerator.JAVA_BOOLEAN, + "resultIsNull", + v => s"$v = true;") + val value = ctx.addMutableState( + CodeGenerator.javaType(firstExpr.dataType), + "resultValue", + v => s"$v = ${CodeGenerator.defaultValue(firstExpr.dataType)};") + ExprCode( - code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, firstExpr.dataType)) } } // Part 2: switch/case statements val cases = projections.zipWithIndex.map { case (exprs, row) => - var updateCode = "" - for (col <- exprs.indices) { + val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col => if (!sameOutput(col)) { - val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx) - updateCode += - s""" - |${ev.code} - |${outputColumns(col).isNull} = ${ev.isNull}; - |${outputColumns(col).value} = ${ev.value}; - """.stripMargin + val boundExpr = BindReferences.bindReference(exprs(col), attributeSeq) + val exprCode = boundExpr.genCode(ctx) + val inputVars = CodeGenerator.getLocalInputVariableValues(ctx, boundExpr)._1 + Some(((col, exprCode), inputVars)) + } else { + None + } + }.unzip + + val updateCode = exprCodesWithIndices.map { case (col, ev) => + s""" + |${ev.code} + |${outputColumns(col).isNull} = ${ev.isNull}; + |${outputColumns(col).value} = ${ev.value}; + """.stripMargin + } + + val splitThreshold = SQLConf.get.methodSplitThreshold + val inputVars = 
inputVarSets.foldLeft(Set.empty[VariableValue])(_ ++ _) val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVars.toSeq) val maybeSplitUpdateCode = if (CodeGenerator.isValidParamLength(paramLength) && exprCodesWithIndices.map(_._2.code.length).sum > splitThreshold) { Review comment: Here we only check whether the code under the current switch case exceeds the threshold. It seems to me that we also need to accumulate the non-split code: if the accumulated code goes over the threshold, we should split it out in later switch cases. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org