sarutak commented on code in PR #55912:
URL: https://github.com/apache/spark/pull/55912#discussion_r3265720895


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala:
##########
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+
+/**
+ * Performs an AS-OF join using sort-merge. Both sides are co-partitioned
+ * by the equi-join keys and sorted by (equi-join keys, as-of key).
+ * For each left row, we scan the right side to find the nearest match
+ * satisfying the as-of condition.
+ *
+ * Note: When there are no equi-keys, both sides are collected into a
+ * single partition (AllTuples). The right side is fully buffered in
+ * memory, so this operator is not suitable for large right-side tables
+ * without equi-keys.
+ *
+ * Note: Even with equi-keys, the scanner keeps the right rows of the
+ * current equi-key group in a plain in-memory buffer, so a heavily
+ * skewed equi-key group carries the same memory risk.
+ *
+ * @param leftKeys equi-join key expressions for the left side; when
+ *                 empty, both children are required to be AllTuples
+ * @param rightKeys equi-join key expressions for the right side
+ * @param leftAsOfExpr as-of key of the left side; it is the last column
+ *                     of the sort order required from the left child
+ * @param rightAsOfExpr as-of key of the right side; it is the last column
+ *                      of the sort order required from the right child
+ * @param asOfCondition condition a right row must satisfy to be a
+ *                      candidate match for a left row
+ * @param orderExpression expression (the distance) that the scanner
+ *                        minimizes among the candidate right rows
+ * @param joinType join type; for LeftOuter the right-side output
+ *                 attributes are made nullable
+ * @param condition optional extra condition, handed to the scanner as
+ *                  its residual condition
+ * @param left left child plan
+ * @param right right child plan
+ */
+case class SortMergeAsOfJoinExec(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    leftAsOfExpr: Expression,
+    rightAsOfExpr: Expression,
+    asOfCondition: Expression,
+    orderExpression: Expression,
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryExecNode {
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext,
+      "number of output rows"))
+
+  override def output: Seq[Attribute] = joinType match {
+    case LeftOuter =>
+      left.output ++ right.output.map(_.withNullability(true))
+    case _ =>
+      left.output ++ right.output
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    // Output preserves left-side ordering (equi-keys + as-of key)
+    left.outputOrdering
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (leftKeys.isEmpty) {
+      AllTuples :: AllTuples :: Nil
+    } else {
+      ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: 
Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+
+      SortOrder(leftAsOfExpr, Ascending)
+    val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+
+      SortOrder(rightAsOfExpr, Ascending)
+    leftOrdering :: rightOrdering :: Nil
+  }
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val scanner = new SortMergeAsOfJoinScanner(
+        leftIter,
+        rightIter,
+        left.output,
+        right.output,
+        leftKeys,
+        rightKeys,
+        asOfCondition,
+        orderExpression,
+        joinType,
+        condition,
+        numOutputRows
+      )
+      // Register cleanup to release the right-side buffer on task completion
+      TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close())
+      scanner.iterator
+    }
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan,
+      newRight: SparkPlan): SortMergeAsOfJoinExec = {
+    copy(left = newLeft, right = newRight)
+  }
+}
+
+/**
+ * Performs the sort-merge AS-OF join scan.
+ *
+ * Both inputs are sorted by (equi-keys, as-of key) ascending. For each
+ * left row within an equi-key group, we find the right row that satisfies
+ * the as-of condition and minimizes the order expression (distance).
+ *
+ * Since the right side is sorted by as-of key within each group, for
+ * backward joins we scan right-to-left and stop at the first match
+ * (exploiting sort order for early termination).
+ */
+private[joins] class SortMergeAsOfJoinScanner(
+    leftIter: Iterator[InternalRow],
+    rightIter: Iterator[InternalRow],
+    leftOutput: Seq[Attribute],
+    rightOutput: Seq[Attribute],
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    asOfCondition: Expression,
+    orderExpression: Expression,
+    joinType: JoinType,
+    residualCondition: Option[Expression],
+    numOutputRows: SQLMetric) {
+
+  private val joinedOutput = leftOutput ++ rightOutput
+  private val joinedRow = new JoinedRow()
+  private val resultProjection =
+    UnsafeProjection.create(joinedOutput, joinedOutput)
+
+  // Bound expressions for evaluating conditions on joined rows
+  private val boundAsOfCond = bindReference(asOfCondition, joinedOutput)
+  private val boundOrderExpr = bindReference(orderExpression, joinedOutput)
+  private val boundResidualCond =
+    residualCondition.map(bindReference(_, joinedOutput))
+
+  // Key ordering for equi-join keys
+  private val equiKeyOrdering: Option[BaseOrdering] =
+    if (leftKeys.nonEmpty) {
+      val keyAttributes = leftKeys.zipWithIndex.map { case (key, i) =>
+        AttributeReference(s"key_$i", key.dataType, key.nullable)()
+      }
+      Some(GenerateOrdering.generate(
+        keyAttributes.map(SortOrder(_, Ascending)), keyAttributes))
+    } else {
+      None
+    }
+
+  // Projections to extract equi-keys for comparison
+  private val leftKeyProj = UnsafeProjection.create(leftKeys, leftOutput)
+  private val rightKeyProj = UnsafeProjection.create(rightKeys, rightOutput)
+
+  // Ordering for the distance metric
+  private val distanceOrdering =
+    TypeUtils.getInterpretedOrdering(orderExpression.dataType)
+
+  // Determine scan direction based on the as-of condition.
+  // Backward (left >= right): best match is at end of sorted buffer -> 
right-to-left
+  // Forward (left <= right): best match is at start -> left-to-right
+  // Nearest / unknown: left-to-right (works correctly, just no early 
termination
+  // guarantee for the "as-of not satisfied" shortcut)
+  private val scanRightToLeft: Boolean = {
+    def isBackward(expr: Expression): Boolean = expr match {
+      case GreaterThanOrEqual(_, _) => true
+      case GreaterThan(_, _) => true
+      case And(l, _) => isBackward(l)
+      case _ => false
+    }
+    isBackward(asOfCondition)
+  }
+
+  // Null row for LeftOuter when no match is found
+  private val nullRightRow = new GenericInternalRow(rightOutput.length)
+
+  // Right-side buffer: holds right rows for the current equi-key group.
+  // Rows are sorted by as-of key ascending (guaranteed by 
requiredChildOrdering).
+  private val rightGroupBuffer = new ArrayBuffer[InternalRow]()

Review Comment:
   Thank you for the suggestion. I'll extend the class-level note to document 
the skewed equi-key group risk as well. Replacing ArrayBuffer with 
ExternalAppendOnlyUnsafeRowArray for spill support is a good idea. The main 
challenge is that the current right-to-left scan (`findBestRightToLeft`) relies 
on indexed access from the end of the buffer, and 
`ExternalAppendOnlyUnsafeRowArray` only supports forward iteration. This would 
require either switching to a "last match wins" forward scan (losing early 
termination for backward joins) or a hybrid approach (in-memory reverse scan 
below a threshold, forward scan with spill above it). I'd like to tackle this 
as a follow-up.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to