aokolnychyi commented on code in PR #41448:
URL: https://github.com/apache/spark/pull/41448#discussion_r1222497417


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala:
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.roaringbitmap.longlong.Roaring64Bitmap
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.BasePredicate
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Projection
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Instruction, Keep, ROW_ID, Split}
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+case class MergeRowsExec(
+    isSourceRowPresent: Expression,
+    isTargetRowPresent: Expression,
+    matchedInstructions: Seq[Instruction],
+    notMatchedInstructions: Seq[Instruction],
+    notMatchedBySourceInstructions: Seq[Instruction],
+    checkCardinality: Boolean,
+    output: Seq[Attribute],
+    child: SparkPlan) extends UnaryExecNode {
+
+  @transient override lazy val producedAttributes: AttributeSet = {
+    AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
+  }
+
+  @transient override lazy val references: AttributeSet = child.outputSet
+
+  override def simpleString(maxFields: Int): String = {
+    s"MergeRowsExec${truncatedString(output, "[", ", ", "]", maxFields)}"
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
+    copy(child = newChild)
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitions(processPartition)
+  }
+
+  private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+    val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
+    val isTargetRowPresentPred = createPredicate(isTargetRowPresent)
+
+    val matchedInstructionExecs = planInstructions(matchedInstructions)
+    val notMatchedInstructionExecs = planInstructions(notMatchedInstructions)
+    val notMatchedBySourceInstructionExecs = planInstructions(notMatchedBySourceInstructions)
+
+    val cardinalityValidator = if (checkCardinality) {
+      val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
+      assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+      BitmapCardinalityValidator(rowIdOrdinal)
+    } else {
+      NoopCardinalityValidator
+    }
+
+    val mergeIterator = new MergeRowIterator(
+      rowIterator, cardinalityValidator, isTargetRowPresentPred, isSourceRowPresentPred,
+      matchedInstructionExecs, notMatchedInstructionExecs, notMatchedBySourceInstructionExecs)
+
+    // null indicates a record must be discarded
+    mergeIterator.filter(_ != null)
+  }
+
+  private def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+    UnsafeProjection.create(exprs, child.output)
+  }
+
+  private def createPredicate(expr: Expression): BasePredicate = {
+    GeneratePredicate.generate(expr, child.output)
+  }
+
+  private def planInstructions(instructions: Seq[Instruction]): Seq[InstructionExec] = {
+    instructions.map {
+      case Keep(cond, output) =>
+        KeepExec(createPredicate(cond), createProjection(output))
+
+      case Split(cond, output, otherOutput) =>
+        SplitExec(createPredicate(cond), createProjection(output), createProjection(otherOutput))
+
+      case other =>
+        throw new AnalysisException(s"Unexpected instruction: $other")
+    }
+  }
+
+  sealed trait InstructionExec {
+    def condition: BasePredicate
+  }
+
+  case class KeepExec(condition: BasePredicate, projection: Projection) extends InstructionExec {
+    def apply(row: InternalRow): InternalRow = projection.apply(row)
+  }
+
+  case class SplitExec(
+      condition: BasePredicate,
+      projection: Projection,
+      otherProjection: Projection) extends InstructionExec {
+    def projectRow(row: InternalRow): InternalRow = projection.apply(row)
+    def projectExtraRow(row: InternalRow): InternalRow = otherProjection.apply(row)
+  }
+
+  sealed trait CardinalityValidator {
+    def validate(row: InternalRow): Unit
+  }
+
+  object NoopCardinalityValidator extends CardinalityValidator {
+    def validate(row: InternalRow): Unit = {}
+  }
+
+  /**
+   * A simple cardinality validator that keeps track of seen row IDs in a roaring bitmap.
+   * This validator assumes the target table is never broadcasted or replicated, which guarantees
+   * matches for one target row are always co-located in the same partition.
+   *
+   * IDs are generated by [[org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID]].
+   */
+  case class BitmapCardinalityValidator(rowIdOrdinal: Int) extends CardinalityValidator {
+    private val matchedRowIds = new Roaring64Bitmap()

Review Comment:
   Added above.
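
   For readers following the thread: below is a rough, self-contained sketch of how a bitmap-based cardinality check like the one documented above can detect a target row matching more than one source row. This is an illustration only, not the code in this PR; the class name, the error thrown, and the use of `InternalRow.getLong` for the row ID are assumptions.

   ```scala
   import org.apache.spark.sql.catalyst.InternalRow
   import org.roaringbitmap.longlong.Roaring64Bitmap

   // Hypothetical illustration, not the PR implementation.
   // Every target row carries a unique 64-bit ID; seeing the same ID twice in a
   // partition means one target row matched multiple source rows, which is a
   // cardinality violation for MERGE.
   class DuplicateMatchDetector(rowIdOrdinal: Int) {
     private val matchedRowIds = new Roaring64Bitmap()

     def validate(row: InternalRow): Unit = {
       val rowId = row.getLong(rowIdOrdinal)
       if (matchedRowIds.contains(rowId)) {
         // the exec node presumably raises a dedicated QueryExecutionErrors error;
         // a generic exception keeps this sketch self-contained
         throw new IllegalStateException(
           "MERGE cardinality violation: a target row matched multiple source rows")
       }
       matchedRowIds.addLong(rowId)
     }
   }
   ```

   Because the plan never broadcasts or replicates the target side, all matches for a given target row land in the same partition, so a per-partition bitmap is sufficient and no cross-partition state is needed.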



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

