juliuszsompolski commented on code in PR #52399:
URL: https://github.com/apache/spark/pull/52399#discussion_r2758435272
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -92,6 +95,277 @@ case class MergeRowsExec(
child.execute().mapPartitions(processPartition)
}
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].inputRDDs()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row:
ExprCode): String = {
+ // Save the input variables that were passed to doConsume
+ val inputCurrentVars = input
+
+ // code for instruction execution code
+ generateInstructionExecutionCode(ctx, inputCurrentVars)
+ }
+
+
+ /**
+ * code for cardinality validation
+ */
+ private def generateCardinalityValidationCode(ctx: CodegenContext,
rowIdOrdinal: Int,
+ input: Seq[ExprCode]):
ExprCode = {
+ val bitmapClass = classOf[Roaring64Bitmap]
+ val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+ v => s"$v = new ${bitmapClass.getName}();")
+
+ val currentRowId = input(rowIdOrdinal)
+ val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName +
".MODULE$"
+ val code =
+ code"""
+ |${currentRowId.code}
+ |if ($rowIdBitmap.contains(${currentRowId.value})) {
+ | throw
$queryExecutionErrorsClass.mergeCardinalityViolationError();
+ |}
+ |$rowIdBitmap.add(${currentRowId.value});
+ """.stripMargin
+ ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+ }
+
+ /**
+ * Generate code for instruction execution based on row presence conditions
+ */
+ private def generateInstructionExecutionCode(ctx: CodegenContext,
+ inputExprs: Seq[ExprCode]):
String = {
+
+ // code for evaluating src/tgt presence conditions
+ val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent,
child.output, inputExprs)
+ val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent,
child.output, inputExprs)
+
+ // code for each instruction type
+ val matchedInstructionsCode = generateInstructionsCode(ctx,
matchedInstructions,
+ "matched", inputExprs, sourcePresent = true)
+ val notMatchedInstructionsCode = generateInstructionsCode(ctx,
notMatchedInstructions,
+ "notMatched", inputExprs, sourcePresent = true)
+ val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+ notMatchedBySourceInstructions, "notMatchedBySource", inputExprs,
sourcePresent = false)
+
+ val cardinalityValidationCode = if (checkCardinality) {
+ val rowIdOrdinal = child.output.indexWhere(attr =>
conf.resolver(attr.name, ROW_ID))
+ assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+ generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+ } else {
+ ""
+ }
+
+ s"""
+ |${sourcePresentExpr.code}
+ |${targetPresentExpr.code}
+ |
+ |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+ | $cardinalityValidationCode
+ | $matchedInstructionsCode
+ |} else if (${sourcePresentExpr.value}) {
+ | $notMatchedInstructionsCode
+ |} else if (${targetPresentExpr.value}) {
+ | $notMatchedBySourceInstructionsCode
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generate code for executing a sequence of instructions
+ */
+ private def generateInstructionsCode(ctx: CodegenContext, instructions:
Seq[Instruction],
+ instructionType: String,
+ inputExprs: Seq[ExprCode],
+ sourcePresent: Boolean): String = {
+ if (instructions.isEmpty) {
+ ""
+ } else {
+ val instructionCodes = instructions.map(instruction =>
+ generateSingleInstructionCode(ctx, instruction, inputExprs,
sourcePresent))
+
+ s"""
+ |${instructionCodes.mkString("\n")}
+ |return;
+ """.stripMargin
+ }
+ }
+
+ private def generateSingleInstructionCode(ctx: CodegenContext,
+ instruction: Instruction,
+ inputExprs: Seq[ExprCode],
+ sourcePresent: Boolean): String = {
+ instruction match {
+ case Keep(context, condition, outputExprs) =>
+ val projectionExpr = generateProjectionCode(ctx, outputExprs,
inputExprs)
+ val code = generatePredicateCode(ctx, condition, child.output,
inputExprs)
+
+ // Generate metric updates based on context
+ val metricUpdateCode = generateMetricUpdateCode(ctx, context,
sourcePresent)
+
+ s"""
+ |${code.code}
+ |if (${code.value}) {
+ | $metricUpdateCode
+ | ${consume(ctx, projectionExpr)}
+ | return;
+ |}
+ """.stripMargin
+
+ case Discard(condition) =>
+ val code = generatePredicateCode(ctx, condition, child.output,
inputExprs)
+ val metricUpdateCode = generateDeleteMetricUpdateCode(ctx,
sourcePresent)
+
+ s"""
+ |${code.code}
+ |if (${code.value}) {
+ | $metricUpdateCode
+ | return; // Discar row
+ |}
+ """.stripMargin
+
+ case Split(condition, outputExprs, otherOutputExprs) =>
+ val projectionExpr = generateProjectionCode(ctx, outputExprs,
inputExprs)
+ val otherProjectionExpr = generateProjectionCode(ctx,
otherOutputExprs, inputExprs)
+ val code = generatePredicateCode(ctx, condition, child.output,
inputExprs)
+ val metricUpdateCode = generateUpdateMetricUpdateCode(ctx,
sourcePresent)
+
+ s"""
+ |${code.code}
+ |if (${code.value}) {
+ | $metricUpdateCode
+ | ${consume(ctx, projectionExpr)}
+ | ${consume(ctx, otherProjectionExpr)}
+ | return;
+ |}
+ """.stripMargin
+ case _ =>
+ // Codegen not implemented
+ throw new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_3073",
+ messageParameters = Map("instruction" -> instruction.toString))
+ }
+ }
+
+ /**
+ * metric update code based on Keep's context
+ */
+ private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
+ sourcePresent: Boolean): String = {
+ context match {
+ case Copy =>
+ val copyMetric = metricTerm(ctx, "numTargetRowsCopied")
+ s"$copyMetric.add(1);"
+
+ case Insert =>
+ val insertMetric = metricTerm(ctx, "numTargetRowsInserted")
+ s"$insertMetric.add(1);"
+
+ case Update =>
+ generateUpdateMetricUpdateCode(ctx, sourcePresent)
+
+ case Delete =>
+ generateDeleteMetricUpdateCode(ctx, sourcePresent)
+
+ case _ =>
+ throw new IllegalArgumentException(s"Unexpected context for KeepExec:
$context")
+ }
+ }
+
+ private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
+ sourcePresent: Boolean): String =
{
+ val updateMetric = metricTerm(ctx, "numTargetRowsUpdated")
+ if (sourcePresent) {
+ val matchedUpdateMetric = metricTerm(ctx, "numTargetRowsMatchedUpdated")
+
+ s"""
+ |$updateMetric.add(1);
+ |$matchedUpdateMetric.add(1);
+ """.stripMargin
+ } else {
+ val notMatchedBySourceUpdateMetric = metricTerm(ctx,
"numTargetRowsNotMatchedBySourceUpdated")
+
+ s"""
+ |$updateMetric.add(1);
+ |$notMatchedBySourceUpdateMetric.add(1);
+ """.stripMargin
+ }
+ }
+
+ private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
+ sourcePresent: Boolean): String =
{
+ val deleteMetric = metricTerm(ctx, "numTargetRowsDeleted")
+ if (sourcePresent) {
+ val matchedDeleteMetric = metricTerm(ctx, "numTargetRowsMatchedDeleted")
+
+ s"""
+ |$deleteMetric.add(1);
+ |$matchedDeleteMetric.add(1);
+ """.stripMargin
+ } else {
+ val notMatchedBySourceDeleteMetric = metricTerm(ctx,
"numTargetRowsNotMatchedBySourceDeleted")
+
+ s"""
+ |$deleteMetric.add(1);
+ |$notMatchedBySourceDeleteMetric.add(1);
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Helper method to save and restore CodegenContext state for code
generation.
+ *
+ * This is needed because when generating code for expressions, the
CodegenContext
+ * state (currentVars and INPUT_ROW) gets modified during expression
evaluation.
+ * This method temporarily sets the context to the input variables from
doConsume
+ * and restores the original state after the block completes.
+ */
+ private def withCodegenContext[T](
+ ctx: CodegenContext,
+ inputCurrentVars: Seq[ExprCode])(block: => T): T = {
+ val originalCurrentVars = ctx.currentVars
+ val originalInputRow = ctx.INPUT_ROW
+ try {
+ // Set to the input variables saved in doConsume
+ ctx.currentVars = inputCurrentVars
+ block
+ } finally {
+ // Restore original context
+ ctx.currentVars = originalCurrentVars
+ ctx.INPUT_ROW = originalInputRow
+ }
+ }
+
+ private def generatePredicateCode(ctx: CodegenContext,
+ predicate: Expression,
+ inputAttrs: Seq[Attribute],
+ inputCurrentVars: Seq[ExprCode]): ExprCode
= {
+ withCodegenContext(ctx, inputCurrentVars) {
+ val boundPredicate = BindReferences.bindReference(predicate, inputAttrs)
+ val ev = boundPredicate.genCode(ctx)
+ val predicateVar = ctx.freshName("predicateResult")
+ val code = code"""
+ |${ev.code}
+ |boolean $predicateVar = !${ev.isNull} && ${ev.value};
+ """.stripMargin
+ ExprCode(code, FalseLiteral,
+ JavaCode.variable(predicateVar, BooleanType))
+ }
+ }
+
+ private def generateProjectionCode(ctx: CodegenContext,
+ outputExprs: Seq[Expression],
+ inputCurrentVars: Seq[ExprCode]):
Seq[ExprCode] = {
+ withCodegenContext(ctx, inputCurrentVars) {
+ val boundExprs = outputExprs.map(BindReferences.bindReference(_,
child.output))
+ boundExprs.map(_.genCode(ctx))
+ }
+ }
Review Comment:
It looks like
[object CodeGenerator](https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala#L1435C1-L1435C21)
would be the best place to add generic helpers.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]