peter-toth commented on code in PR #55912: URL: https://github.com/apache/spark/pull/55912#discussion_r3265639062
########## sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala: ########## @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. + * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). 
The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. + */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _ =>
Review Comment:

The `case _ =>` catch-all silently produces an Inner-style schema for any join type other than `LeftOuter`. Today the only other `joinType` that reaches this operator is `Inner` (validated at the DataFrame `joinAsOf` and Spark Connect entry points), so this works. But if a future change ever wires `RightOuter`, `FullOuter`, `LeftSemi`, etc. through here (including someone instantiating the case class directly in a test), the schema would be silently wrong: no nullability injected on the null-supplying side, and a leaf-attribute mismatch with the rest of the pipeline. Consider following the pattern of similar operators that fail loudly:

```scala
// Needs `import org.apache.spark.SparkException`; the file does not import it yet.
override def output: Seq[Attribute] = joinType match {
  case LeftOuter =>
    left.output ++ right.output.map(_.withNullability(true))
  case _: InnerLike =>
    left.output ++ right.output
  case other =>
    throw SparkException.internalError(
      s"$nodeName does not support join type: $other")
}
```

A `require(joinType == Inner || joinType == LeftOuter, ...)` in the case class body would be even more defensive: it would surface the violation at construction time rather than waiting until something calls `.output`.
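To illustrate the fail-loud idea outside Spark, here is a minimal sketch on hypothetical toy types (`JoinKind`, `Col`, and `ToyAsOfJoin` are illustrative stand-ins, not Catalyst's `JoinType`, `Attribute`, or the real operator). It combines a construction-time `require` with an `output` match that has no silent catch-all; `IllegalStateException` stands in for `SparkException.internalError`.

```scala
// Hypothetical toy types; stand-ins for Catalyst's JoinType/Attribute, not Spark code.
sealed trait JoinKind
case object InnerJoin extends JoinKind
case object LeftOuterJoin extends JoinKind
case object RightOuterJoin extends JoinKind

final case class Col(name: String, nullable: Boolean) {
  def withNullability(n: Boolean): Col = copy(nullable = n)
}

final case class ToyAsOfJoin(joinKind: JoinKind, left: Seq[Col], right: Seq[Col]) {
  // Construction-time guard: an unsupported join kind fails immediately with
  // IllegalArgumentException instead of producing a wrong schema later.
  require(joinKind == InnerJoin || joinKind == LeftOuterJoin,
    s"ToyAsOfJoin does not support join type: $joinKind")

  // No silent catch-all: each supported kind is listed and anything else throws
  // (IllegalStateException stands in for SparkException.internalError).
  def output: Seq[Col] = joinKind match {
    case LeftOuterJoin => left ++ right.map(_.withNullability(true))
    case InnerJoin => left ++ right
    case other => throw new IllegalStateException(s"unsupported join type: $other")
  }
}

object ToyAsOfJoinDemo {
  def main(args: Array[String]): Unit = {
    val l = Seq(Col("id", nullable = false), Col("t", nullable = false))
    val r = Seq(Col("v", nullable = false))
    println(ToyAsOfJoin(LeftOuterJoin, l, r).output.map(_.nullable)) // List(false, false, true)
    println(scala.util.Try(ToyAsOfJoin(RightOuterJoin, l, r)).isFailure) // true
  }
}
```

With both guards in place, the `case other` branch is unreachable in practice; it remains as a backstop if the `require` is ever loosened.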
########## sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala: ########## @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. + * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). 
The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. + */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + + override def outputOrdering: Seq[SortOrder] = { + // Output preserves left-side ordering (equi-keys + as-of key) + left.outputOrdering + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (leftKeys.isEmpty) { + AllTuples :: AllTuples :: Nil + } else { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(leftAsOfExpr, Ascending) + val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(rightAsOfExpr, Ascending) + leftOrdering :: rightOrdering :: Nil + } + + override def outputPartitioning: Partitioning = left.outputPartitioning + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val scanner = new SortMergeAsOfJoinScanner( + leftIter, + rightIter, + left.output, + right.output, + leftKeys, + rightKeys, + asOfCondition, + orderExpression, + joinType, + condition, + numOutputRows + ) + // Register cleanup to release the right-side buffer on 
task completion + TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close()) + scanner.iterator + } + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SortMergeAsOfJoinExec = { + copy(left = newLeft, right = newRight) + } +} + +/** + * Performs the sort-merge AS-OF join scan. + * + * Both inputs are sorted by (equi-keys, as-of key) ascending. For each + * left row within an equi-key group, we find the right row that satisfies + * the as-of condition and minimizes the order expression (distance). + * + * Since the right side is sorted by as-of key within each group, for + * backward joins we scan right-to-left and stop at the first match + * (exploiting sort order for early termination). + */ +private[joins] class SortMergeAsOfJoinScanner( + leftIter: Iterator[InternalRow], + rightIter: Iterator[InternalRow], + leftOutput: Seq[Attribute], + rightOutput: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + residualCondition: Option[Expression], + numOutputRows: SQLMetric) { + + private val joinedOutput = leftOutput ++ rightOutput + private val joinedRow = new JoinedRow() + private val resultProjection = + UnsafeProjection.create(joinedOutput, joinedOutput) + + // Bound expressions for evaluating conditions on joined rows + private val boundAsOfCond = bindReference(asOfCondition, joinedOutput) + private val boundOrderExpr = bindReference(orderExpression, joinedOutput) + private val boundResidualCond = + residualCondition.map(bindReference(_, joinedOutput)) + + // Key ordering for equi-join keys + private val equiKeyOrdering: Option[BaseOrdering] = + if (leftKeys.nonEmpty) { + val keyAttributes = leftKeys.zipWithIndex.map { case (key, i) => + AttributeReference(s"key_$i", key.dataType, key.nullable)() + } + Some(GenerateOrdering.generate( + keyAttributes.map(SortOrder(_, Ascending)), keyAttributes)) + 
} else { + None + } + + // Projections to extract equi-keys for comparison + private val leftKeyProj = UnsafeProjection.create(leftKeys, leftOutput) + private val rightKeyProj = UnsafeProjection.create(rightKeys, rightOutput) + + // Ordering for the distance metric + private val distanceOrdering = + TypeUtils.getInterpretedOrdering(orderExpression.dataType) + + // Determine scan direction based on the as-of condition. + // Backward (left >= right): best match is at end of sorted buffer -> right-to-left + // Forward (left <= right): best match is at start -> left-to-right + // Nearest / unknown: left-to-right (works correctly, just no early termination + // guarantee for the "as-of not satisfied" shortcut) + private val scanRightToLeft: Boolean = { + def isBackward(expr: Expression): Boolean = expr match { + case GreaterThanOrEqual(_, _) => true + case GreaterThan(_, _) => true + case And(l, _) => isBackward(l) + case _ => false + } + isBackward(asOfCondition) + } + + // Null row for LeftOuter when no match is found + private val nullRightRow = new GenericInternalRow(rightOutput.length) + + // Right-side buffer: holds right rows for the current equi-key group. + // Rows are sorted by as-of key ascending (guaranteed by requiredChildOrdering). + private val rightGroupBuffer = new ArrayBuffer[InternalRow]()
Review Comment:

The class-level comment (lines 41-43) warns that the no-equi-key case "is not suitable for large right-side tables" because `bufferAllRight` materializes the whole right partition. The same OOM risk exists in the equi-key case for any single equi-key group with a large right-side cardinality (`bufferRightGroup` accumulates every right row sharing one key into the same `ArrayBuffer`). That case isn't called out, so a reader would assume equi-key joins are memory-safe. Two suggestions, in increasing order of effort:

1. Extend the class-level note: "For each equi-key group, all right rows with that key are buffered. Skewed equi-key groups can OOM."
2. Replace the `ArrayBuffer` with `ExternalAppendOnlyUnsafeRowArray` (used by `SortMergeJoinExec` for the same buffer-the-current-group pattern) so the buffer spills to disk once it passes its in-memory and spill thresholds (`SortMergeJoinExec` takes these from `spark.sql.sortMergeJoinExec.buffer.in.memory.threshold` and `spark.sql.sortMergeJoinExec.buffer.spill.threshold`). The right-to-left scan in `findBestRightToLeft` would need to be expressed in terms of the array's forward iterator (`generateIterator`) rather than indexed access: non-trivial but tractable.

Either is fine for this PR; without one of them, the operator carries an undocumented OOM mode that the benchmark numbers don't surface.

########## sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala: ########## @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. + * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. 
+ */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryExecNode { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + + override def outputOrdering: Seq[SortOrder] = { + // Output preserves left-side ordering (equi-keys + as-of key) + left.outputOrdering + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (leftKeys.isEmpty) { + AllTuples :: AllTuples :: Nil + } else { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(leftAsOfExpr, Ascending) + val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(rightAsOfExpr, Ascending) + leftOrdering :: rightOrdering :: Nil + } + + override def outputPartitioning: Partitioning = left.outputPartitioning + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val scanner = new SortMergeAsOfJoinScanner( + leftIter, + rightIter, + left.output, + right.output, + leftKeys, + rightKeys, + asOfCondition, + orderExpression, + joinType, + condition, + numOutputRows + ) + // Register cleanup to release the right-side buffer on task completion + TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close()) + scanner.iterator + } + } + + override 
protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SortMergeAsOfJoinExec = { + copy(left = newLeft, right = newRight) + } +} + +/** + * Performs the sort-merge AS-OF join scan. + * + * Both inputs are sorted by (equi-keys, as-of key) ascending. For each + * left row within an equi-key group, we find the right row that satisfies + * the as-of condition and minimizes the order expression (distance). + * + * Since the right side is sorted by as-of key within each group, for + * backward joins we scan right-to-left and stop at the first match + * (exploiting sort order for early termination). + */ +private[joins] class SortMergeAsOfJoinScanner( + leftIter: Iterator[InternalRow], + rightIter: Iterator[InternalRow], + leftOutput: Seq[Attribute], + rightOutput: Seq[Attribute], + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + residualCondition: Option[Expression], + numOutputRows: SQLMetric) { + + private val joinedOutput = leftOutput ++ rightOutput + private val joinedRow = new JoinedRow() + private val resultProjection = + UnsafeProjection.create(joinedOutput, joinedOutput) + + // Bound expressions for evaluating conditions on joined rows + private val boundAsOfCond = bindReference(asOfCondition, joinedOutput) + private val boundOrderExpr = bindReference(orderExpression, joinedOutput) + private val boundResidualCond = + residualCondition.map(bindReference(_, joinedOutput)) + + // Key ordering for equi-join keys + private val equiKeyOrdering: Option[BaseOrdering] = + if (leftKeys.nonEmpty) { + val keyAttributes = leftKeys.zipWithIndex.map { case (key, i) => + AttributeReference(s"key_$i", key.dataType, key.nullable)() + } + Some(GenerateOrdering.generate( + keyAttributes.map(SortOrder(_, Ascending)), keyAttributes)) + } else { + None + } + + // Projections to extract equi-keys for comparison + private val leftKeyProj = 
UnsafeProjection.create(leftKeys, leftOutput) + private val rightKeyProj = UnsafeProjection.create(rightKeys, rightOutput) + + // Ordering for the distance metric + private val distanceOrdering = + TypeUtils.getInterpretedOrdering(orderExpression.dataType) + + // Determine scan direction based on the as-of condition. + // Backward (left >= right): best match is at end of sorted buffer -> right-to-left + // Forward (left <= right): best match is at start -> left-to-right + // Nearest / unknown: left-to-right (works correctly, just no early termination + // guarantee for the "as-of not satisfied" shortcut) + private val scanRightToLeft: Boolean = {
Review Comment:

Two issues with this heuristic:

1. **Recurses only on the left child of `And`.** With `tolerance`, the as-of condition is conjunctive (`backward AND right.t >= left.t - tolerance`), and the analyzer/optimizer is free to reorder conjuncts. The current code happens to work for the canonical shape `RewriteAsOfJoin` builds, but if the structure ever flips (e.g. an optimizer rule normalizing `And` operands), `isBackward` would incorrectly return `false` for backward-with-tolerance. Either look at both sides (`case And(l, r) => isBackward(l) || isBackward(r)`) or use a `find`-style search.
2. **The comment doesn't say this is a performance heuristic, not a correctness one.** Both `findBestRightToLeft` and `findBestLeftToRight` produce the correct optimum for any direction; the only difference is whether the early-termination shortcut fires. A reader who sees `isBackward` returning `false` for what they know is a backward join might assume that's a correctness bug, when it's actually just a missed optimization. Worth a one-line note: "If this misclassifies, the scan still produces the correct result; only the early-termination shortcut is lost."

Concrete proposal:

```scala
// Best-effort hint -- a wrong answer here only loses early termination,
// not correctness. The tree-wide search tolerates conjunct reordering.
private val scanRightToLeft: Boolean = asOfCondition.exists {
  case _: GreaterThanOrEqual | _: GreaterThan => true
  case _ => false
}
```

(Uses `exists`, which `Expression` inherits from `TreeNode` and which short-circuits on the first match across the whole tree.)
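A self-contained sketch of how the current heuristic and the two suggested fixes differ, using a hypothetical mini expression tree (`Expr`, `Attr`, `Sub`, `Gte`, `Lte`, and `AndE` are stand-ins for Catalyst's `Expression` subclasses, and `Expr.exists` imitates `TreeNode.exists`). For brevity only `>=` is modeled, and the tolerance bound is assumed to arrive as a `<=` comparison placed first; if both conjuncts were written as `>=`, even the left-only version would happen to return `true`.

```scala
// Hypothetical mini expression tree standing in for Catalyst's Expression (not Spark code).
sealed trait Expr {
  def children: Seq[Expr]
  // Pre-order search that stops at the first match, modeled on TreeNode.exists.
  def exists(f: Expr => Boolean): Boolean = f(this) || children.exists(_.exists(f))
}
final case class Attr(name: String) extends Expr { def children: Seq[Expr] = Nil }
final case class Sub(l: Expr, r: Expr) extends Expr { def children: Seq[Expr] = Seq(l, r) }
final case class Gte(l: Expr, r: Expr) extends Expr { def children: Seq[Expr] = Seq(l, r) }
final case class Lte(l: Expr, r: Expr) extends Expr { def children: Seq[Expr] = Seq(l, r) }
final case class AndE(l: Expr, r: Expr) extends Expr { def children: Seq[Expr] = Seq(l, r) }

object ScanDirection {
  // Shape of the current heuristic: only the left operand of And is inspected.
  def isBackwardLeftOnly(e: Expr): Boolean = e match {
    case _: Gte => true
    case AndE(l, _) => isBackwardLeftOnly(l)
    case _ => false
  }

  // Suggestion 1: inspect both conjuncts.
  def isBackwardBothSides(e: Expr): Boolean = e match {
    case _: Gte => true
    case AndE(l, r) => isBackwardBothSides(l) || isBackwardBothSides(r)
    case _ => false
  }

  // Suggestion 2: search the whole tree.
  def isBackwardExists(e: Expr): Boolean = e.exists {
    case _: Gte => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    val (lt, rt, tol) = (Attr("left.t"), Attr("right.t"), Attr("tol"))
    // Backward join with tolerance; the tolerance bound (left.t - tol <= right.t) comes first.
    val reordered = AndE(Lte(Sub(lt, tol), rt), Gte(lt, rt))
    println(isBackwardLeftOnly(reordered))  // false
    println(isBackwardBothSides(reordered)) // true
    println(isBackwardExists(reordered))    // true
  }
}
```

Note that the tree-wide search returns `true` for any condition containing a `>=` anywhere, including a non-backward condition whose tolerance bound is written as `>=`; under point 2 above, that costs only the early-termination shortcut.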
