Kimahriman commented on code in PR #34558:
URL: https://github.com/apache/spark/pull/34558#discussion_r3297942369
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala:
##########
@@ -886,6 +1179,114 @@ case class ArrayAggregate(
}
}
+ protected def nullSafeCodeGen(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ f: String => String): ExprCode = {
+ val argumentGen = argument.genCode(ctx)
+ val resultCode = f(argumentGen.value)
+
+ if (nullable) {
+ val nullSafeEval = ctx.nullSafeExec(argument.nullable,
argumentGen.isNull)(resultCode)
+ ev.copy(code = code"""
+ |${argumentGen.code}
+ |boolean ${ev.isNull} = ${argumentGen.isNull};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
+ |$nullSafeEval
+ """)
+ } else {
+ ev.copy(code = code"""
+ |${argumentGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """, isNull = FalseLiteral)
+ }
+ }
+
+ protected def assignVar(
+ varCode: ExprCode,
+ atomicVar: String,
+ value: String,
+ isNull: String,
+ nullable: Boolean): String = {
+ val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable)
+ if (nullable) {
+ s"""
+ ${varCode.value} = $value;
+ ${varCode.isNull} = $isNull;
+ $atomicAssign
+ """
+ } else {
+ s"""
+ ${varCode.value} = $value;
+ $atomicAssign
+ """
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar),
varCodes => {
+ val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes
+
+ nullSafeCodeGen(ctx, ev, arg => {
+ val numElements = ctx.freshName("numElements")
+ val i = ctx.freshName("i")
+
+ val zeroCode = zero.genCode(ctx)
+ val mergeCode = merge.genCode(ctx)
+ val finishCode = finish.genCode(ctx)
+
+ val elementAssignment = assignArrayElement(ctx, arg, elementCode,
elementVar, i)
+ val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name,
+ accForMergeVar.value)
+ val finishAtomic = ctx.addReferenceObj(accForFinishVar.name,
+ accForFinishVar.value)
+
+ val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType)
+ val finishJavaType = CodeGenerator.javaType(accForFinishVar.dataType)
+
+ // Some expressions return internal buffers that we have to copy
+ val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) {
+ s"${mergeCode.value}"
+ } else {
+ s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})"
+ }
+
+ val nullCheck = if (nullable) {
+ s"${ev.isNull} = ${finishCode.isNull};"
+ } else {
+ ""
+ }
+
+ val initialAssignment = assignVar(accForMergeCode, mergeAtomic,
zeroCode.value,
+ zeroCode.isNull, zero.nullable)
+
+ val mergeAssignment = assignVar(accForMergeCode, mergeAtomic,
mergeCopy,
+ mergeCode.isNull, merge.nullable)
Review Comment:
Fixed in `261147f456d`.
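`initialAssignment` and `mergeAssignment` now both pass
`accForMergeVar.nullable` to `assignVar` (previously they passed
`zero.nullable` and `merge.nullable`). Because `assignVar` only writes the
`isNull` flag when its `nullable` argument is true, the old code never reset the
accumulator's null flag when `zero` or `merge` was non-nullable. Now, whenever
the accumulator lambda variable is nullable, its null flag is written on every
assignment. This clears stale null state both after assigning a non-null `zero`
and after assigning a non-null merge result.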
I added generated-code regression tests for both paths described in the comment:
- stale null state within a single aggregate loop
- stale null state carried across rows in the same generated partition

I verified that both tests fail before the fix and pass after it.
Verified with:
```bash
build/sbt 'sql/testOnly org.apache.spark.sql.DataFrameFunctionsSuite -- -z "aggregate function - generated code clears accumulator null state"'
```
Result: passed (`2` tests, `0` failures).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]