juliuszsompolski commented on code in PR #52399:
URL: https://github.com/apache/spark/pull/52399#discussion_r2753899887


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -92,6 +95,277 @@ case class MergeRowsExec(
     child.execute().mapPartitions(processPartition)
   }
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    // Save the input variables that were passed to doConsume
+    val inputCurrentVars = input
+
+    // code for instruction execution code
+    generateInstructionExecutionCode(ctx, inputCurrentVars)
+  }
+
+
+  /**
+   * code for cardinality validation
+   */
+  private def generateCardinalityValidationCode(ctx: CodegenContext, 
rowIdOrdinal: Int,
+                                                input: Seq[ExprCode]): 
ExprCode = {
+    val bitmapClass = classOf[Roaring64Bitmap]
+    val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+      v => s"$v = new ${bitmapClass.getName}();")
+
+    val currentRowId = input(rowIdOrdinal)
+    val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + 
".MODULE$"
+    val code =
+      code"""
+            |${currentRowId.code}
+            |if ($rowIdBitmap.contains(${currentRowId.value})) {
+            |  throw 
$queryExecutionErrorsClass.mergeCardinalityViolationError();
+            |}
+            |$rowIdBitmap.add(${currentRowId.value});
+     """.stripMargin
+    ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+  }
+
+  /**
+   * Generate code for instruction execution based on row presence conditions
+   */
+  private def generateInstructionExecutionCode(ctx: CodegenContext,
+                                               inputExprs: Seq[ExprCode]): 
String = {
+
+    // code for evaluating src/tgt presence conditions
+    val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, 
child.output, inputExprs)
+    val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, 
child.output, inputExprs)
+
+    // code for each instruction type
+    val matchedInstructionsCode = generateInstructionsCode(ctx, 
matchedInstructions,
+      "matched", inputExprs, sourcePresent = true)
+    val notMatchedInstructionsCode = generateInstructionsCode(ctx, 
notMatchedInstructions,
+      "notMatched", inputExprs, sourcePresent = true)
+    val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+      notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, 
sourcePresent = false)
+
+    val cardinalityValidationCode = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => 
conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+    } else {
+      ""
+    }
+
+    s"""
+       |${sourcePresentExpr.code}
+       |${targetPresentExpr.code}
+       |
+       |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+       |  $cardinalityValidationCode
+       |  $matchedInstructionsCode
+       |} else if (${sourcePresentExpr.value}) {
+       |  $notMatchedInstructionsCode
+       |} else if (${targetPresentExpr.value}) {
+       |  $notMatchedBySourceInstructionsCode
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for executing a sequence of instructions
+   */
+  private def generateInstructionsCode(ctx: CodegenContext, instructions: 
Seq[Instruction],
+                                       instructionType: String,
+                                       inputExprs: Seq[ExprCode],
+                                       sourcePresent: Boolean): String = {
+    if (instructions.isEmpty) {
+      ""
+    } else {
+      val instructionCodes = instructions.map(instruction =>
+        generateSingleInstructionCode(ctx, instruction, inputExprs, 
sourcePresent))
+
+      s"""
+         |${instructionCodes.mkString("\n")}
+         |return;
+       """.stripMargin
+    }
+  }
+
+  private def generateSingleInstructionCode(ctx: CodegenContext,
+                                            instruction: Instruction,
+                                            inputExprs: Seq[ExprCode],
+                                            sourcePresent: Boolean): String = {
+    instruction match {
+      case Keep(context, condition, outputExprs) =>
+        val projectionExpr = generateProjectionCode(ctx, outputExprs, 
inputExprs)
+        val code = generatePredicateCode(ctx, condition, child.output, 
inputExprs)
+
+        // Generate metric updates based on context
+        val metricUpdateCode = generateMetricUpdateCode(ctx, context, 
sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  ${consume(ctx, projectionExpr)}
+           |  return;
+           |}
+       """.stripMargin
+
+      case Discard(condition) =>
+        val code = generatePredicateCode(ctx, condition, child.output, 
inputExprs)
+        val metricUpdateCode = generateDeleteMetricUpdateCode(ctx, 
sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  return; // Discar row
+           |}
+       """.stripMargin
+
+      case Split(condition, outputExprs, otherOutputExprs) =>
+        val projectionExpr = generateProjectionCode(ctx, outputExprs, 
inputExprs)
+        val otherProjectionExpr = generateProjectionCode(ctx, 
otherOutputExprs, inputExprs)
+        val code = generatePredicateCode(ctx, condition, child.output, 
inputExprs)
+        val metricUpdateCode = generateUpdateMetricUpdateCode(ctx, 
sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  ${consume(ctx, projectionExpr)}
+           |  ${consume(ctx, otherProjectionExpr)}
+           |  return;
+           |}
+       """.stripMargin
+      case _ =>
+        // Codegen not implemented
+        throw new SparkUnsupportedOperationException(
+          errorClass = "_LEGACY_ERROR_TEMP_3073",
+          messageParameters = Map("instruction" -> instruction.toString))
+    }
+  }
+
+  /**
+   * metric update code based on Keep's context
+   */
+  private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
+                                       sourcePresent: Boolean): String = {
+    context match {
+      case Copy =>
+        val copyMetric = metricTerm(ctx, "numTargetRowsCopied")
+        s"$copyMetric.add(1);"
+
+      case Insert =>
+        val insertMetric = metricTerm(ctx, "numTargetRowsInserted")
+        s"$insertMetric.add(1);"
+
+      case Update =>
+        generateUpdateMetricUpdateCode(ctx, sourcePresent)
+
+      case Delete =>
+        generateDeleteMetricUpdateCode(ctx, sourcePresent)
+
+      case _ =>
+        throw new IllegalArgumentException(s"Unexpected context for KeepExec: 
$context")
+    }
+  }
+
+  private def generateUpdateMetricUpdateCode(ctx: CodegenContext,
+                                             sourcePresent: Boolean): String = 
{
+    val updateMetric = metricTerm(ctx, "numTargetRowsUpdated")
+    if (sourcePresent) {
+      val matchedUpdateMetric = metricTerm(ctx, "numTargetRowsMatchedUpdated")
+
+      s"""
+         |$updateMetric.add(1);
+         |$matchedUpdateMetric.add(1);
+       """.stripMargin
+    } else {
+      val notMatchedBySourceUpdateMetric = metricTerm(ctx, 
"numTargetRowsNotMatchedBySourceUpdated")
+
+      s"""
+         |$updateMetric.add(1);
+         |$notMatchedBySourceUpdateMetric.add(1);
+       """.stripMargin
+    }
+  }
+
+  private def generateDeleteMetricUpdateCode(ctx: CodegenContext,
+                                             sourcePresent: Boolean): String = 
{
+    val deleteMetric = metricTerm(ctx, "numTargetRowsDeleted")
+    if (sourcePresent) {
+      val matchedDeleteMetric = metricTerm(ctx, "numTargetRowsMatchedDeleted")
+
+      s"""
+         |$deleteMetric.add(1);
+         |$matchedDeleteMetric.add(1);
+       """.stripMargin
+    } else {
+      val notMatchedBySourceDeleteMetric = metricTerm(ctx, 
"numTargetRowsNotMatchedBySourceDeleted")
+
+      s"""
+         |$deleteMetric.add(1);
+         |$notMatchedBySourceDeleteMetric.add(1);
+       """.stripMargin
+    }
+  }
+
+  /**
+   * Helper method to save and restore CodegenContext state for code 
generation.
+   *
+   * This is needed because when generating code for expressions, the 
CodegenContext
+   * state (currentVars and INPUT_ROW) gets modified during expression 
evaluation.
+   * This method temporarily sets the context to the input variables from 
doConsume
+   * and restores the original state after the block completes.
+   */
+  private def withCodegenContext[T](
+      ctx: CodegenContext,
+      inputCurrentVars: Seq[ExprCode])(block: => T): T = {
+    val originalCurrentVars = ctx.currentVars
+    val originalInputRow = ctx.INPUT_ROW
+    try {
+      // Set to the input variables saved in doConsume
+      ctx.currentVars = inputCurrentVars
+      block
+    } finally {
+      // Restore original context
+      ctx.currentVars = originalCurrentVars
+      ctx.INPUT_ROW = originalInputRow
+    }
+  }
+
+  private def generatePredicateCode(ctx: CodegenContext,
+                                    predicate: Expression,
+                                    inputAttrs: Seq[Attribute],
+                                    inputCurrentVars: Seq[ExprCode]): ExprCode 
= {
+    withCodegenContext(ctx, inputCurrentVars) {
+      val boundPredicate = BindReferences.bindReference(predicate, inputAttrs)
+      val ev = boundPredicate.genCode(ctx)
+      val predicateVar = ctx.freshName("predicateResult")
+      val code = code"""
+                       |${ev.code}
+                       |boolean $predicateVar = !${ev.isNull} && ${ev.value};
+                        """.stripMargin
+      ExprCode(code, FalseLiteral,
+        JavaCode.variable(predicateVar, BooleanType))
+    }
+  }
+
+  private def generateProjectionCode(ctx: CodegenContext,
+                                     outputExprs: Seq[Expression],
+                                     inputCurrentVars: Seq[ExprCode]): 
Seq[ExprCode] = {
+    withCodegenContext(ctx, inputCurrentVars) {
+      val boundExprs = outputExprs.map(BindReferences.bindReference(_, 
child.output))
+      boundExprs.map(_.genCode(ctx))
+    }
+  }

Review Comment:
   This code is quite generic. Aren't there existing helpers for this? If not, 
maybe move it to a generic helper (it is stateless, so the companion `object 
CodegenSupport` would be a suitable place)?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -92,6 +95,277 @@ case class MergeRowsExec(
     child.execute().mapPartitions(processPartition)
   }
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    // Save the input variables that were passed to doConsume
+    val inputCurrentVars = input
+
+    // code for instruction execution code
+    generateInstructionExecutionCode(ctx, inputCurrentVars)
+  }
+
+
+  /**
+   * code for cardinality validation
+   */
+  private def generateCardinalityValidationCode(ctx: CodegenContext, 
rowIdOrdinal: Int,
+                                                input: Seq[ExprCode]): 
ExprCode = {
+    val bitmapClass = classOf[Roaring64Bitmap]
+    val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+      v => s"$v = new ${bitmapClass.getName}();")
+
+    val currentRowId = input(rowIdOrdinal)
+    val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + 
".MODULE$"
+    val code =
+      code"""
+            |${currentRowId.code}
+            |if ($rowIdBitmap.contains(${currentRowId.value})) {
+            |  throw 
$queryExecutionErrorsClass.mergeCardinalityViolationError();
+            |}
+            |$rowIdBitmap.add(${currentRowId.value});
+     """.stripMargin
+    ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+  }
+
+  /**
+   * Generate code for instruction execution based on row presence conditions
+   */
+  private def generateInstructionExecutionCode(ctx: CodegenContext,
+                                               inputExprs: Seq[ExprCode]): 
String = {
+
+    // code for evaluating src/tgt presence conditions
+    val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, 
child.output, inputExprs)
+    val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, 
child.output, inputExprs)
+
+    // code for each instruction type
+    val matchedInstructionsCode = generateInstructionsCode(ctx, 
matchedInstructions,
+      "matched", inputExprs, sourcePresent = true)
+    val notMatchedInstructionsCode = generateInstructionsCode(ctx, 
notMatchedInstructions,
+      "notMatched", inputExprs, sourcePresent = true)
+    val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+      notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, 
sourcePresent = false)
+
+    val cardinalityValidationCode = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => 
conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+    } else {
+      ""
+    }
+
+    s"""
+       |${sourcePresentExpr.code}
+       |${targetPresentExpr.code}
+       |
+       |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+       |  $cardinalityValidationCode
+       |  $matchedInstructionsCode
+       |} else if (${sourcePresentExpr.value}) {
+       |  $notMatchedInstructionsCode
+       |} else if (${targetPresentExpr.value}) {
+       |  $notMatchedBySourceInstructionsCode
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for executing a sequence of instructions
+   */
+  private def generateInstructionsCode(ctx: CodegenContext, instructions: 
Seq[Instruction],
+                                       instructionType: String,

Review Comment:
   The parameter `instructionType` is not used.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -92,6 +95,277 @@ case class MergeRowsExec(
     child.execute().mapPartitions(processPartition)
   }
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    // Save the input variables that were passed to doConsume
+    val inputCurrentVars = input
+
+    // code for instruction execution code
+    generateInstructionExecutionCode(ctx, inputCurrentVars)
+  }
+
+
+  /**
+   * code for cardinality validation
+   */
+  private def generateCardinalityValidationCode(ctx: CodegenContext, 
rowIdOrdinal: Int,
+                                                input: Seq[ExprCode]): 
ExprCode = {
+    val bitmapClass = classOf[Roaring64Bitmap]
+    val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+      v => s"$v = new ${bitmapClass.getName}();")
+
+    val currentRowId = input(rowIdOrdinal)
+    val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + 
".MODULE$"
+    val code =
+      code"""
+            |${currentRowId.code}
+            |if ($rowIdBitmap.contains(${currentRowId.value})) {
+            |  throw 
$queryExecutionErrorsClass.mergeCardinalityViolationError();
+            |}
+            |$rowIdBitmap.add(${currentRowId.value});
+     """.stripMargin
+    ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+  }
+
+  /**
+   * Generate code for instruction execution based on row presence conditions
+   */
+  private def generateInstructionExecutionCode(ctx: CodegenContext,
+                                               inputExprs: Seq[ExprCode]): 
String = {
+
+    // code for evaluating src/tgt presence conditions
+    val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, 
child.output, inputExprs)
+    val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, 
child.output, inputExprs)
+
+    // code for each instruction type
+    val matchedInstructionsCode = generateInstructionsCode(ctx, 
matchedInstructions,
+      "matched", inputExprs, sourcePresent = true)
+    val notMatchedInstructionsCode = generateInstructionsCode(ctx, 
notMatchedInstructions,
+      "notMatched", inputExprs, sourcePresent = true)
+    val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+      notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, 
sourcePresent = false)
+
+    val cardinalityValidationCode = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => 
conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+    } else {
+      ""
+    }
+
+    s"""
+       |${sourcePresentExpr.code}
+       |${targetPresentExpr.code}
+       |
+       |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+       |  $cardinalityValidationCode
+       |  $matchedInstructionsCode
+       |} else if (${sourcePresentExpr.value}) {
+       |  $notMatchedInstructionsCode
+       |} else if (${targetPresentExpr.value}) {
+       |  $notMatchedBySourceInstructionsCode
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for executing a sequence of instructions
+   */
+  private def generateInstructionsCode(ctx: CodegenContext, instructions: 
Seq[Instruction],
+                                       instructionType: String,
+                                       inputExprs: Seq[ExprCode],
+                                       sourcePresent: Boolean): String = {
+    if (instructions.isEmpty) {
+      ""
+    } else {
+      val instructionCodes = instructions.map(instruction =>
+        generateSingleInstructionCode(ctx, instruction, inputExprs, 
sourcePresent))
+
+      s"""
+         |${instructionCodes.mkString("\n")}
+         |return;
+       """.stripMargin
+    }
+  }
+
+  private def generateSingleInstructionCode(ctx: CodegenContext,
+                                            instruction: Instruction,
+                                            inputExprs: Seq[ExprCode],
+                                            sourcePresent: Boolean): String = {
+    instruction match {
+      case Keep(context, condition, outputExprs) =>
+        val projectionExpr = generateProjectionCode(ctx, outputExprs, 
inputExprs)
+        val code = generatePredicateCode(ctx, condition, child.output, 
inputExprs)
+
+        // Generate metric updates based on context
+        val metricUpdateCode = generateMetricUpdateCode(ctx, context, 
sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  ${consume(ctx, projectionExpr)}
+           |  return;
+           |}
+       """.stripMargin
+
+      case Discard(condition) =>
+        val code = generatePredicateCode(ctx, condition, child.output, 
inputExprs)
+        val metricUpdateCode = generateDeleteMetricUpdateCode(ctx, 
sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  return; // Discar row
+           |}
+       """.stripMargin
+
+      case Split(condition, outputExprs, otherOutputExprs) =>
+        val projectionExpr = generateProjectionCode(ctx, outputExprs, 
inputExprs)
+        val otherProjectionExpr = generateProjectionCode(ctx, 
otherOutputExprs, inputExprs)
+        val code = generatePredicateCode(ctx, condition, child.output, 
inputExprs)
+        val metricUpdateCode = generateUpdateMetricUpdateCode(ctx, 
sourcePresent)
+
+        s"""
+           |${code.code}
+           |if (${code.value}) {
+           |  $metricUpdateCode
+           |  ${consume(ctx, projectionExpr)}
+           |  ${consume(ctx, otherProjectionExpr)}
+           |  return;
+           |}
+       """.stripMargin
+      case _ =>
+        // Codegen not implemented
+        throw new SparkUnsupportedOperationException(
+          errorClass = "_LEGACY_ERROR_TEMP_3073",
+          messageParameters = Map("instruction" -> instruction.toString))
+    }
+  }
+
+  /**
+   * metric update code based on Keep's context
+   */
+  private def generateMetricUpdateCode(ctx: CodegenContext, context: Context,
+                                       sourcePresent: Boolean): String = {
+    context match {
+      case Copy =>
+        val copyMetric = metricTerm(ctx, "numTargetRowsCopied")
+        s"$copyMetric.add(1);"
+
+      case Insert =>
+        val insertMetric = metricTerm(ctx, "numTargetRowsInserted")
+        s"$insertMetric.add(1);"
+
+      case Update =>
+        generateUpdateMetricUpdateCode(ctx, sourcePresent)
+
+      case Delete =>
+        generateDeleteMetricUpdateCode(ctx, sourcePresent)
+
+      case _ =>
+        throw new IllegalArgumentException(s"Unexpected context for KeepExec: 
$context")
+    }
+  }

Review Comment:
   Move this method to the `Instruction` class.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -92,6 +95,277 @@ case class MergeRowsExec(
     child.execute().mapPartitions(processPartition)
   }
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    // Save the input variables that were passed to doConsume
+    val inputCurrentVars = input
+
+    // code for instruction execution code
+    generateInstructionExecutionCode(ctx, inputCurrentVars)
+  }
+
+
+  /**
+   * code for cardinality validation
+   */
+  private def generateCardinalityValidationCode(ctx: CodegenContext, 
rowIdOrdinal: Int,
+                                                input: Seq[ExprCode]): 
ExprCode = {
+    val bitmapClass = classOf[Roaring64Bitmap]
+    val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+      v => s"$v = new ${bitmapClass.getName}();")
+
+    val currentRowId = input(rowIdOrdinal)
+    val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + 
".MODULE$"
+    val code =
+      code"""
+            |${currentRowId.code}
+            |if ($rowIdBitmap.contains(${currentRowId.value})) {
+            |  throw 
$queryExecutionErrorsClass.mergeCardinalityViolationError();
+            |}
+            |$rowIdBitmap.add(${currentRowId.value});
+     """.stripMargin
+    ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+  }
+
+  /**
+   * Generate code for instruction execution based on row presence conditions
+   */
+  private def generateInstructionExecutionCode(ctx: CodegenContext,
+                                               inputExprs: Seq[ExprCode]): 
String = {
+
+    // code for evaluating src/tgt presence conditions
+    val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, 
child.output, inputExprs)
+    val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, 
child.output, inputExprs)
+
+    // code for each instruction type
+    val matchedInstructionsCode = generateInstructionsCode(ctx, 
matchedInstructions,
+      "matched", inputExprs, sourcePresent = true)
+    val notMatchedInstructionsCode = generateInstructionsCode(ctx, 
notMatchedInstructions,
+      "notMatched", inputExprs, sourcePresent = true)
+    val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+      notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, 
sourcePresent = false)
+
+    val cardinalityValidationCode = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => 
conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+    } else {
+      ""
+    }
+
+    s"""
+       |${sourcePresentExpr.code}
+       |${targetPresentExpr.code}
+       |
+       |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+       |  $cardinalityValidationCode
+       |  $matchedInstructionsCode
+       |} else if (${sourcePresentExpr.value}) {
+       |  $notMatchedInstructionsCode
+       |} else if (${targetPresentExpr.value}) {
+       |  $notMatchedBySourceInstructionsCode
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for executing a sequence of instructions
+   */
+  private def generateInstructionsCode(ctx: CodegenContext, instructions: 
Seq[Instruction],
+                                       instructionType: String,
+                                       inputExprs: Seq[ExprCode],
+                                       sourcePresent: Boolean): String = {

Review Comment:
   Scalastyle: put one parameter per line, with a 4-space indent for parameters 
(here and in other places).
   ```suggestion
     private def generateInstructionsCode(
         ctx: CodegenContext,
         instructions: Seq[Instruction],
         instructionType: String,
         inputExprs: Seq[ExprCode],
         sourcePresent: Boolean): String = {
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -92,6 +95,277 @@ case class MergeRowsExec(
     child.execute().mapPartitions(processPartition)
   }
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
+    // Save the input variables that were passed to doConsume
+    val inputCurrentVars = input
+
+    // code for instruction execution code
+    generateInstructionExecutionCode(ctx, inputCurrentVars)
+  }
+
+
+  /**
+   * code for cardinality validation
+   */
+  private def generateCardinalityValidationCode(ctx: CodegenContext, 
rowIdOrdinal: Int,
+                                                input: Seq[ExprCode]): 
ExprCode = {
+    val bitmapClass = classOf[Roaring64Bitmap]
+    val rowIdBitmap = ctx.addMutableState(bitmapClass.getName, "matchedRowIds",
+      v => s"$v = new ${bitmapClass.getName}();")
+
+    val currentRowId = input(rowIdOrdinal)
+    val queryExecutionErrorsClass = QueryExecutionErrors.getClass.getName + 
".MODULE$"
+    val code =
+      code"""
+            |${currentRowId.code}
+            |if ($rowIdBitmap.contains(${currentRowId.value})) {
+            |  throw 
$queryExecutionErrorsClass.mergeCardinalityViolationError();
+            |}
+            |$rowIdBitmap.add(${currentRowId.value});
+     """.stripMargin
+    ExprCode(code, FalseLiteral, JavaCode.variable(rowIdBitmap, bitmapClass))
+  }
+
+  /**
+   * Generate code for instruction execution based on row presence conditions
+   */
+  private def generateInstructionExecutionCode(ctx: CodegenContext,
+                                               inputExprs: Seq[ExprCode]): 
String = {
+
+    // code for evaluating src/tgt presence conditions
+    val sourcePresentExpr = generatePredicateCode(ctx, isSourceRowPresent, 
child.output, inputExprs)
+    val targetPresentExpr = generatePredicateCode(ctx, isTargetRowPresent, 
child.output, inputExprs)
+
+    // code for each instruction type
+    val matchedInstructionsCode = generateInstructionsCode(ctx, 
matchedInstructions,
+      "matched", inputExprs, sourcePresent = true)
+    val notMatchedInstructionsCode = generateInstructionsCode(ctx, 
notMatchedInstructions,
+      "notMatched", inputExprs, sourcePresent = true)
+    val notMatchedBySourceInstructionsCode = generateInstructionsCode(ctx,
+      notMatchedBySourceInstructions, "notMatchedBySource", inputExprs, 
sourcePresent = false)
+
+    val cardinalityValidationCode = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => 
conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      generateCardinalityValidationCode(ctx, rowIdOrdinal, inputExprs).code
+    } else {
+      ""
+    }
+
+    s"""
+       |${sourcePresentExpr.code}
+       |${targetPresentExpr.code}
+       |
+       |if (${targetPresentExpr.value} && ${sourcePresentExpr.value}) {
+       |  $cardinalityValidationCode
+       |  $matchedInstructionsCode
+       |} else if (${sourcePresentExpr.value}) {
+       |  $notMatchedInstructionsCode
+       |} else if (${targetPresentExpr.value}) {
+       |  $notMatchedBySourceInstructionsCode
+       |}
+     """.stripMargin
+  }
+
+  /**
+   * Generate code for executing a sequence of instructions
+   */
+  private def generateInstructionsCode(ctx: CodegenContext, instructions: 
Seq[Instruction],
+                                       instructionType: String,
+                                       inputExprs: Seq[ExprCode],
+                                       sourcePresent: Boolean): String = {
+    if (instructions.isEmpty) {
+      ""
+    } else {
+      val instructionCodes = instructions.map(instruction =>
+        generateSingleInstructionCode(ctx, instruction, inputExprs, 
sourcePresent))
+
+      s"""
+         |${instructionCodes.mkString("\n")}
+         |return;
+       """.stripMargin
+    }
+  }
+
+  private def generateSingleInstructionCode(ctx: CodegenContext,
+                                            instruction: Instruction,
+                                            inputExprs: Seq[ExprCode],
+                                            sourcePresent: Boolean): String = {
+    instruction match {

Review Comment:
   For better code locality and readability, move this method to the 
`Instruction` class, with each subclass overriding it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to