Kimahriman commented on code in PR #34558:
URL: https://github.com/apache/spark/pull/34558#discussion_r3290456907
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala:
##########
@@ -886,6 +1170,114 @@ case class ArrayAggregate(
}
}
+ protected def nullSafeCodeGen(
+ ctx: CodegenContext,
+ ev: ExprCode,
+ f: String => String): ExprCode = {
+ val argumentGen = argument.genCode(ctx)
+ val resultCode = f(argumentGen.value)
+
+ if (nullable) {
+ val nullSafeEval = ctx.nullSafeExec(argument.nullable,
argumentGen.isNull)(resultCode)
+ ev.copy(code = code"""
+ |${argumentGen.code}
+ |boolean ${ev.isNull} = ${argumentGen.isNull};
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
+ |$nullSafeEval
+ """)
+ } else {
+ ev.copy(code = code"""
+ |${argumentGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """, isNull = FalseLiteral)
+ }
+ }
+
+ protected def assignVar(
+ varCode: ExprCode,
+ atomicVar: String,
+ value: String,
+ isNull: String,
+ nullable: Boolean): String = {
+ val atomicAssign = assignAtomic(atomicVar, value, isNull, nullable)
+ if (nullable) {
+ s"""
+ ${varCode.value} = $value;
+ ${varCode.isNull} = $isNull;
+ $atomicAssign
+ """
+ } else {
+ s"""
+ ${varCode.value} = $value;
+ $atomicAssign
+ """
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ ctx.withLambdaVars(Seq(elementVar, accForMergeVar, accForFinishVar),
varCodes => {
+ val Seq(elementCode, accForMergeCode, accForFinishCode) = varCodes
+
+ nullSafeCodeGen(ctx, ev, arg => {
+ val numElements = ctx.freshName("numElements")
+ val i = ctx.freshName("i")
+
+ val zeroCode = zero.genCode(ctx)
+ val mergeCode = merge.genCode(ctx)
+ val finishCode = finish.genCode(ctx)
+
+ val elementAssignment = assignArrayElement(ctx, arg, elementCode,
elementVar, i)
+ val mergeAtomic = ctx.addReferenceObj(accForMergeVar.name,
+ accForMergeVar.value)
+ val finishAtomic = ctx.addReferenceObj(accForFinishVar.name,
+ accForFinishVar.value)
+
+ val mergeJavaType = CodeGenerator.javaType(accForMergeVar.dataType)
+ val finishJavaType = CodeGenerator.javaType(accForFinishVar.dataType)
+
+ // Some expressions return internal buffers that we have to copy
+ val mergeCopy = if (CodeGenerator.isPrimitiveType(merge.dataType)) {
+ s"${mergeCode.value}"
+ } else {
+ s"($mergeJavaType)InternalRow.copyValue(${mergeCode.value})"
+ }
+
+ val nullCheck = if (nullable) {
+ s"${ev.isNull} = ${finishCode.isNull};"
+ } else {
+ ""
+ }
+
+ val initialAssignment = assignVar(accForMergeCode, mergeAtomic,
zeroCode.value,
+ zeroCode.isNull, zero.nullable)
+
+ val mergeAssignment = assignVar(accForMergeCode, mergeAtomic,
mergeCopy,
+ mergeCode.isNull, merge.nullable)
+
+ val finishAssignment = assignVar(accForFinishCode, finishAtomic,
accForMergeCode.value,
+ accForMergeCode.isNull, merge.nullable)
Review Comment:
Thanks, good catch. Fixed in `9f330488952` by making `finishAssignment`
use the nullability of the accumulator lambda variable instead of
`merge.nullable`, so a null `zero` value is still propagated into `finish`
when the input array is empty.
I also added the regression case from your review comment as a test and verified it with:
```bash
build/sbt 'catalyst/testOnly org.apache.spark.sql.catalyst.expressions.HigherOrderFunctionsSuite -- -z ArrayAggregate'
```
Result: passed (`2` tests, `0` failures).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]