cloud-fan commented on code in PR #55912:
URL: https://github.com/apache/spark/pull/55912#discussion_r3333438913

##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala:
##########
@@ -0,0 +1,446 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+
+/**
+ * Performs an AS-OF join using sort-merge. Both sides are co-partitioned
+ * by the equi-join keys and sorted by (equi-join keys, as-of key).
+ * For each left row, we scan the right side to find the nearest match
+ * satisfying the as-of condition.
+ *
+ * Note: When there are no equi-keys, both sides are collected into a
+ * single partition (AllTuples). The right side is fully buffered in
+ * memory, so this operator is not suitable for large right-side tables
+ * without equi-keys. For each equi-key group, all right rows with that
+ * key are also buffered in memory; skewed equi-key groups can OOM.
+ */
+case class SortMergeAsOfJoinExec(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    leftAsOfExpr: Expression,
+    rightAsOfExpr: Expression,
+    asOfCondition: Expression,
+    orderExpression: Expression,
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BaseJoinExec {

Review Comment:
   **Design: this operator re-implements a less-hardened `SortMergeJoinExec`.**

   It shares the entire sort-merge shell with `SortMergeJoinExec`: co-partition by keys (`ClusteredDistribution`), co-sort, `zipPartitions`, buffer one side's key-group and scan. The match semantics genuinely differ (SMJ emits *all* equal-key matches; AS-OF emits *one* nearest inequality-satisfying row), so a drop-in reuse of `SortMergeJoinScanner` isn't possible. But the operator omits three things SMJ already solves, all shareable independent of the match loop:

   1. **Spill**: the group is buffered in an on-heap `ArrayBuffer` (`:210`) rather than `ExternalAppendOnlyUnsafeRowArray` (→ the OOM risk @peter-toth raised).
   2. **Null-key skip**: SMJ skips null join keys (`anyNull`); this operator doesn't (→ the correctness finding below).
   3. **Partitioning surface**: `ShuffledJoin` already provides `requiredChildDistribution` / `outputPartitioning` / `output` / `getKeyOrdering`; this hand-writes them on `BaseJoinExec`.

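To make the difference in match semantics concrete, here is a minimal sketch in plain Scala with no Spark classes. `RightRow`, `smjMatches`, and `backwardAsOfMatch` are hypothetical names chosen for illustration and do not appear in the PR; the buffered group is assumed to be already sorted ascending by the as-of key, and the as-of condition is assumed to be the Backward one with exact matches allowed (`right.asOf <= left.asOf`).

```scala
// Hypothetical stand-in for a buffered right-side row; not a Spark type.
final case class RightRow(asOf: Int, value: String)

object MatchSemanticsSketch {
  // SMJ-style semantics: every row of the equal-key group is a match.
  def smjMatches(group: Seq[RightRow]): Seq[RightRow] = group

  // Backward AS-OF semantics: at most one match, the row with the largest
  // asOf that is still <= leftAsOf. Because `group` is assumed sorted
  // ascending by asOf, the satisfying rows form a prefix, so the loop keeps
  // the last satisfying row and stops at the first failing one.
  def backwardAsOfMatch(leftAsOf: Int, group: Seq[RightRow]): Option[RightRow] = {
    var best: Option[RightRow] = None
    val it = group.iterator
    var done = false
    while (!done && it.hasNext) {
      val r = it.next()
      if (r.asOf <= leftAsOf) best = Some(r) else done = true
    }
    best
  }
}
```

The loop only ever moves forward through the group, which is the access pattern an append-only, iterator-based buffer can serve. A residual condition is deliberately left out of this sketch.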
   Recommended direction: **a shared scanner base + extend `ShuffledJoin`** (factor iterator buffering, null-key skip, advance-to-key-group, and spill-backed group storage into a base; each scanner supplies only the per-group match strategy), **not** folding AS-OF into `SortMergeJoinExec`, which would thread a second match semantics through SMJ's ~600 lines of whole-stage codegen. This resolves the null-key bug and the OOM risk by construction.

   Two caveats: (a) a shared *scanner* base gives robustness/maintenance parity but not whole-stage codegen; that lives in the operator's `doProduce` and stays future work; (b) the Backward "reverse scan needs random access" concern (the cited blocker to spill) is avoidable: since the group is sorted ascending and the as-of predicate is monotone, a *forward* scan keeping the last as-of-satisfying row finds the same nearest match, so `ExternalAppendOnlyUnsafeRowArray`'s forward-only `generateIterator(startIndex)` suffices.

   Non-blocking (an architecture direction); the null-key fix below should land regardless of whether this refactor happens here or in a follow-up.


##########
sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala:
##########
@@ -0,0 +1,563 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.joins.SortMergeAsOfJoinExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class SortMergeAsOfJoinSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key, "true") + } + + override def afterAll(): Unit = { + spark.conf.unset(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key) + super.afterAll() + } + + def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = { + val schema1 = StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil) + val rowSeq1: List[Row] = List( + Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c")) + val df1 = spark.createDataFrame(rowSeq1.asJava, schema1) + + val schema2 = StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil) + val rowSeq2: List[Row] = List( + Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3), + Row(6, "y", 6), Row(7, "z", 7)) + val df2 = spark.createDataFrame(rowSeq2.asJava, schema2) + + (df1, df2) + } + + test("uses SortMergeAsOfJoinExec physical operator") { + val (df1, df2) = prepareForAsOfJoin() + val result = df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.nonEmpty, s"Expected 
SortMergeAsOfJoinExec in plan:\n$plan") + } + + test("backward join - simple") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - usingColumns") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - left outer") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("forward join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) // inner join: no match for 10 + ) + } + + test("nearest join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "nearest"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - tolerance = 1") { + val (df1, df2) = 
prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner",
+        tolerance = functions.lit(1),
+        allowExactMatches = true, direction = "backward"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1),
+        Row(10, "z", "c", null, null, null)
+      ).filter(_.get(3) != null)
+    )
+  }
+
+  test("backward join - allowExactMatches = false") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null,
+        allowExactMatches = false, direction = "backward"),
+      Seq(
+        // left.a=1: no right row with a < 1 → no match

Review Comment:
   Non-ASCII `→` in a comment fails Spark's scalastyle non-ASCII rule; use `->`.
   ```suggestion
           // left.a=1: no right row with a < 1 -> no match
   ```


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala:
##########
@@ -0,0 +1,446 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. + * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. For each equi-key group, all right rows with that + * key are also buffered in memory; skewed equi-key groups can OOM. 
+ */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BaseJoinExec { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _: InnerLike => + left.output ++ right.output + case other => + throw SparkException.internalError( + s"$nodeName does not support join type: $other") + } + + override def outputOrdering: Seq[SortOrder] = { + // Output preserves left-side ordering (equi-keys + as-of key) + left.outputOrdering + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (leftKeys.isEmpty) { + AllTuples :: AllTuples :: Nil + } else { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(leftAsOfExpr, Ascending) + val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(rightAsOfExpr, Ascending) + leftOrdering :: rightOrdering :: Nil + } + + override def outputPartitioning: Partitioning = joinType match { + case _: InnerLike => + PartitioningCollection( + Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => left.outputPartitioning + case other => + throw SparkException.internalError( + s"$nodeName does not support join type: $other") + } + + // Determine scan direction based on the order expression (distance metric). 
+ // This is a performance heuristic only -- if it misclassifies, the scan + // still produces the correct result; only the early-termination shortcut + // is lost. + // + // orderExpression is direction-unique by construction: + // Backward: Subtract(leftAsOf, rightAsOf) -> right-to-left + // Forward: Subtract(rightAsOf, leftAsOf) -> left-to-right + // Nearest: If(...) -> left-to-right + private val scanRightToLeft: Boolean = orderExpression match { + case Subtract(l, _, _) if l.semanticEquals(leftAsOfExpr) => true + case _ => false + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val scanFromRight = scanRightToLeft + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val scanner = new SortMergeAsOfJoinScanner( + leftIter, + rightIter, + left.output, + right.output, + leftKeys, + rightKeys, + asOfCondition, + orderExpression, + joinType, + condition, + numOutputRows, + scanFromRight + ) + // Register cleanup to release the right-side buffer on task completion + TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close()) + scanner.iterator + } + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SortMergeAsOfJoinExec = { + copy(left = newLeft, right = newRight) + } +} + +/** + * Performs the sort-merge AS-OF join scan. + * + * Both inputs are sorted by (equi-keys, as-of key) ascending. For each + * left row within an equi-key group, we find the right row that satisfies + * the as-of condition and minimizes the order expression (distance). + * + * Since the right side is sorted by as-of key within each group, for + * backward joins we scan right-to-left and stop at the first match + * (exploiting sort order for early termination). 
+ */
+private[joins] class SortMergeAsOfJoinScanner(
+    leftIter: Iterator[InternalRow],
+    rightIter: Iterator[InternalRow],
+    leftOutput: Seq[Attribute],
+    rightOutput: Seq[Attribute],
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    asOfCondition: Expression,
+    orderExpression: Expression,
+    joinType: JoinType,
+    residualCondition: Option[Expression],
+    numOutputRows: SQLMetric,
+    scanRightToLeft: Boolean) {
+
+  private val joinedOutput = leftOutput ++ rightOutput
+  private val joinedRow = new JoinedRow()
+  private val resultProjection =
+    UnsafeProjection.create(joinedOutput, joinedOutput)
+
+  // Bound expressions for evaluating conditions on joined rows
+  private val boundAsOfCond = bindReference(asOfCondition, joinedOutput)
+  private val boundOrderExpr = bindReference(orderExpression, joinedOutput)
+  private val boundResidualCond =
+    residualCondition.map(bindReference(_, joinedOutput))

Review Comment:
   The residual-condition path is exercised by no test. The `joinAsOf(joinExprs: Column, ...)` overload can place non-equi pair-correlated predicates into `condition`; the strategy routes them here as `residualCondition`, and the scanner interleaves `boundResidualCond` evaluation with distance-based early termination (it skips residual-failing rows without early-stopping). That interaction is subtle: all suite cases use `usingColumns` (pure `EqualTo`) or no condition. Suggest adding a residual test (e.g. an equi-key join plus a `left.val > right.val` residual) for both Backward and Nearest.


##########
sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala:
##########
@@ -0,0 +1,563 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.joins.SortMergeAsOfJoinExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class SortMergeAsOfJoinSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key, "true") + } + + override def afterAll(): Unit = { + spark.conf.unset(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key) + super.afterAll() + } + + def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = { + val schema1 = StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil) + val rowSeq1: List[Row] = List( + Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c")) + val df1 = spark.createDataFrame(rowSeq1.asJava, schema1) + + val schema2 = StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil) + val rowSeq2: List[Row] = List( + Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3), + Row(6, "y", 6), Row(7, "z", 7)) + val df2 = 
spark.createDataFrame(rowSeq2.asJava, schema2) + + (df1, df2) + } + + test("uses SortMergeAsOfJoinExec physical operator") { + val (df1, df2) = prepareForAsOfJoin() + val result = df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.nonEmpty, s"Expected SortMergeAsOfJoinExec in plan:\n$plan") + } + + test("backward join - simple") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - usingColumns") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - left outer") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("forward join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", null, null, null) + 
).filter(_.get(3) != null) // inner join: no match for 10
+    )
+  }
+
+  test("nearest join") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null,
+        allowExactMatches = true, direction = "nearest"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1),
+        Row(5, "y", "b", 6, "y", 6),
+        Row(10, "z", "c", 7, "z", 7)
+      )
+    )
+  }
+
+  test("backward join - tolerance = 1") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner",
+        tolerance = functions.lit(1),
+        allowExactMatches = true, direction = "backward"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1),
+        Row(10, "z", "c", null, null, null)
+      ).filter(_.get(3) != null)
+    )
+  }
+
+  test("backward join - allowExactMatches = false") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null,
+        allowExactMatches = false, direction = "backward"),
+      Seq(
+        // left.a=1: no right row with a < 1 → no match
+        // left.a=5: right.a=3 (3 < 5) → match

Review Comment:
   Non-ASCII `→`; use `->`.
   ```suggestion
           // left.a=5: right.a=3 (3 < 5) -> match
   ```


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeAsOfJoinExec.scala:
##########
@@ -0,0 +1,446 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} + +/** + * Performs an AS-OF join using sort-merge. Both sides are co-partitioned + * by the equi-join keys and sorted by (equi-join keys, as-of key). + * For each left row, we scan the right side to find the nearest match + * satisfying the as-of condition. + * + * Note: When there are no equi-keys, both sides are collected into a + * single partition (AllTuples). The right side is fully buffered in + * memory, so this operator is not suitable for large right-side tables + * without equi-keys. For each equi-key group, all right rows with that + * key are also buffered in memory; skewed equi-key groups can OOM. 
+ */ +case class SortMergeAsOfJoinExec( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftAsOfExpr: Expression, + rightAsOfExpr: Expression, + asOfCondition: Expression, + orderExpression: Expression, + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) extends BaseJoinExec { + + override lazy val metrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, + "number of output rows")) + + override def output: Seq[Attribute] = joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _: InnerLike => + left.output ++ right.output + case other => + throw SparkException.internalError( + s"$nodeName does not support join type: $other") + } + + override def outputOrdering: Seq[SortOrder] = { + // Output preserves left-side ordering (equi-keys + as-of key) + left.outputOrdering + } + + override def requiredChildDistribution: Seq[Distribution] = { + if (leftKeys.isEmpty) { + AllTuples :: AllTuples :: Nil + } else { + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + val leftOrdering = leftKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(leftAsOfExpr, Ascending) + val rightOrdering = rightKeys.map(SortOrder(_, Ascending)) :+ + SortOrder(rightAsOfExpr, Ascending) + leftOrdering :: rightOrdering :: Nil + } + + override def outputPartitioning: Partitioning = joinType match { + case _: InnerLike => + PartitioningCollection( + Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => left.outputPartitioning + case other => + throw SparkException.internalError( + s"$nodeName does not support join type: $other") + } + + // Determine scan direction based on the order expression (distance metric). 
+ // This is a performance heuristic only -- if it misclassifies, the scan + // still produces the correct result; only the early-termination shortcut + // is lost. + // + // orderExpression is direction-unique by construction: + // Backward: Subtract(leftAsOf, rightAsOf) -> right-to-left + // Forward: Subtract(rightAsOf, leftAsOf) -> left-to-right + // Nearest: If(...) -> left-to-right + private val scanRightToLeft: Boolean = orderExpression match { + case Subtract(l, _, _) if l.semanticEquals(leftAsOfExpr) => true + case _ => false + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val scanFromRight = scanRightToLeft + + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val scanner = new SortMergeAsOfJoinScanner( + leftIter, + rightIter, + left.output, + right.output, + leftKeys, + rightKeys, + asOfCondition, + orderExpression, + joinType, + condition, + numOutputRows, + scanFromRight + ) + // Register cleanup to release the right-side buffer on task completion + TaskContext.get().addTaskCompletionListener[Unit](_ => scanner.close()) + scanner.iterator + } + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): SortMergeAsOfJoinExec = { + copy(left = newLeft, right = newRight) + } +} + +/** + * Performs the sort-merge AS-OF join scan. + * + * Both inputs are sorted by (equi-keys, as-of key) ascending. For each + * left row within an equi-key group, we find the right row that satisfies + * the as-of condition and minimizes the order expression (distance). + * + * Since the right side is sorted by as-of key within each group, for + * backward joins we scan right-to-left and stop at the first match + * (exploiting sort order for early termination). 
+ */
+private[joins] class SortMergeAsOfJoinScanner(
+    leftIter: Iterator[InternalRow],
+    rightIter: Iterator[InternalRow],
+    leftOutput: Seq[Attribute],
+    rightOutput: Seq[Attribute],
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    asOfCondition: Expression,
+    orderExpression: Expression,
+    joinType: JoinType,
+    residualCondition: Option[Expression],
+    numOutputRows: SQLMetric,
+    scanRightToLeft: Boolean) {
+
+  private val joinedOutput = leftOutput ++ rightOutput
+  private val joinedRow = new JoinedRow()
+  private val resultProjection =
+    UnsafeProjection.create(joinedOutput, joinedOutput)
+
+  // Bound expressions for evaluating conditions on joined rows
+  private val boundAsOfCond = bindReference(asOfCondition, joinedOutput)
+  private val boundOrderExpr = bindReference(orderExpression, joinedOutput)
+  private val boundResidualCond =
+    residualCondition.map(bindReference(_, joinedOutput))
+
+  // Key ordering for equi-join keys
+  private val equiKeyOrdering: Option[BaseOrdering] =
+    if (leftKeys.nonEmpty) {
+      val keyAttributes = leftKeys.zipWithIndex.map { case (key, i) =>
+        AttributeReference(s"key_$i", key.dataType, key.nullable)()
+      }
+      Some(GenerateOrdering.generate(

Review Comment:
   **Correctness: NULL `EqualTo` equi-keys are incorrectly matched.** (An instance of the design comment on the class declaration above.)

   `equiKeyOrdering` is a `GenerateOrdering` over `SortOrder(_, Ascending)`, whose comparator treats two NULLs as equal (`compare == 0`). So a left row and a right row that both carry a NULL equi-key land in the same group (`advanceRightTo`/`bufferRightGroup`) and produce a join result. But the equi-key comes from `EqualTo`, whose SQL semantics are `NULL = NULL -> NULL` (no match), which is exactly what the MIN_BY baseline produces, since it puts `EqualTo(l, r)` inside a `Filter`.

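The mismatch can be reproduced outside Spark. Below is a minimal sketch in plain Scala that models a nullable key as `Option[Int]`. `sqlEqualTo`, `sortCompare`, and `keysJoin` are hypothetical helpers, not Spark APIs: the first mimics SQL's three-valued `=`, the second mimics an ascending nulls-first sort comparator (an assumption about what the generated ordering does for two NULLs, consistent with the `compare == 0` behaviour described above), and the third shows a guard in the spirit of SMJ's `anyNull` check.

```scala
object NullKeySketch {
  // SQL `=`: NULL on either side yields NULL (modelled here as None),
  // which a Filter treats as "no match".
  def sqlEqualTo(l: Option[Int], r: Option[Int]): Option[Boolean] =
    for (a <- l; b <- r) yield a == b

  // Sort comparator: NULLs sort first and two NULLs compare as equal.
  def sortCompare(l: Option[Int], r: Option[Int]): Int = (l, r) match {
    case (None, None) => 0
    case (None, _) => -1
    case (_, None) => 1
    case (Some(a), Some(b)) => Integer.compare(a, b)
  }

  // Guarded grouping: a null key on either side never joins, even though
  // the comparator alone would place both rows in the same group.
  def keysJoin(l: Option[Int], r: Option[Int]): Boolean =
    l.isDefined && r.isDefined && sortCompare(l, r) == 0
}
```

Grouping on `sortCompare(...) == 0` alone would pair two NULL keys, whereas `sqlEqualTo` yields no match for them; the guard in `keysJoin` brings the two back in line.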
   The standard `SortMergeJoinScanner` codifies the same contract by skipping null join keys (`streamedRowKey.anyNull` / `advancedBufferedToRowWithNullFreeJoinKey` / `assert(!bufferedRowKey.anyNull)`). This operator has no such guard, so it diverges from both and emits extra rows.

   Note the `extractEquiJoinKeys` comment at `SparkStrategies.scala:206` is inverted: it says `EqualNullSafe` is excluded to avoid "incorrectly match null-keyed rows," but `EqualNullSafe` is precisely the operator where NULLs *should* match; the null-match hazard lives in the retained `EqualTo` path.

   Suggested fix: in the scanner, treat any null-valued equi-key as a non-match: skip null-key left rows (Inner: skip; LeftOuter: emit null right) and skip null-key right rows when buffering, mirroring `SortMergeJoinScanner`. And please add a test with NULL `EqualTo` keys (the existing "null as-of keys" test uses no equi-keys, so this path is uncovered).


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AsOfJoinBenchmark.scala:
##########
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.classic +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmark to measure AS-OF join performance: sort-merge operator vs correlated subquery. + * To run this benchmark: + * {{{ + * 1. build/sbt "sql/Test/runMain <this class>" + * 2. generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>" + * Results will be written to + * "benchmarks/AsOfJoinBenchmark-results.txt". + * }}} + */ +object AsOfJoinBenchmark extends SqlBasedBenchmark { + + private def doAsOfJoin( + left: classic.DataFrame, + right: classic.DataFrame, + usingColumns: Seq[String]): Unit = { + left.joinAsOf( + right, left.col("ts"), right.col("ts"), + usingColumns = usingColumns, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward" + ).noop() + } + + private def asOfJoinBenchmark( + leftRows: Int, + rightRows: Int, + numGroups: Int): Unit = { + val left: classic.DataFrame = spark.range(leftRows).select( + (col("id") % numGroups).as("group_id"), + col("id").as("ts"), + lit("left_val").as("left_val") + ).toDF().asInstanceOf[classic.DataFrame] + val right: classic.DataFrame = spark.range(rightRows).select( + (col("id") % numGroups).as("group_id"), + (col("id") * 3 / 2).as("ts"), + lit("right_val").as("right_val") + ).toDF().asInstanceOf[classic.DataFrame] + + val benchmark = new Benchmark( + s"AS-OF Join (left=$leftRows, right=$rightRows, groups=$numGroups)", + leftRows, + output = output) + + benchmark.addCase("Correlated subquery (baseline)") { _ => + withSQLConf( + SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq("group_id")) + } + } + + benchmark.addCase("Sort-merge AS-OF join") { _ => + withSQLConf( + SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "true", + 
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq("group_id")) + } + } + + benchmark.run() + } + + private def asOfJoinNoEquiKeyBenchmark( + leftRows: Int, rightRows: Int): Unit = { + val left: classic.DataFrame = spark.range(leftRows).select( + col("id").as("ts"), + lit("left_val").as("left_val") + ).toDF().asInstanceOf[classic.DataFrame] + val right: classic.DataFrame = spark.range(rightRows).select( + (col("id") * 3 / 2).as("ts"), + lit("right_val").as("right_val") + ).toDF().asInstanceOf[classic.DataFrame] + + val benchmark = new Benchmark( + s"AS-OF Join no equi-key (left=$leftRows, right=$rightRows)", + leftRows, + output = output) + + benchmark.addCase("Correlated subquery (baseline)") { _ => + withSQLConf( + SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "false", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq.empty) + } + } + + benchmark.addCase("Sort-merge AS-OF join") { _ => + withSQLConf( + SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + doAsOfJoin(left, right, Seq.empty) + } + } + + benchmark.run() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("AS-OF Join Benchmark") { + // 10K left x 10K right, 100 groups — both paths feasible Review Comment: Non-ASCII em-dash `—` in a comment fails Spark's scalastyle non-ASCII rule; use an ASCII hyphen. ```suggestion // 10K left x 10K right, 100 groups - both paths feasible ``` ########## sql/core/src/test/scala/org/apache/spark/sql/SortMergeAsOfJoinSuite.scala: ########## @@ -0,0 +1,563 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.joins.SortMergeAsOfJoinExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class SortMergeAsOfJoinSuite extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key, "true") + } + + override def afterAll(): Unit = { + spark.conf.unset(SQLConf.SORT_MERGE_AS_OF_JOIN_ENABLED.key) + super.afterAll() + } + + def prepareForAsOfJoin(): (classic.DataFrame, classic.DataFrame) = { + val schema1 = StructType( + StructField("a", IntegerType, false) :: + StructField("b", StringType, false) :: + StructField("left_val", StringType, false) :: Nil) + val rowSeq1: List[Row] = List( + Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c")) + val df1 = spark.createDataFrame(rowSeq1.asJava, schema1) + + val schema2 = StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: + StructField("right_val", IntegerType) :: Nil) + val rowSeq2: List[Row] = List( + Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3), + Row(6, "y", 6), Row(7, "z", 7)) + val df2 = 
spark.createDataFrame(rowSeq2.asJava, schema2) + + (df1, df2) + } + + test("uses SortMergeAsOfJoinExec physical operator") { + val (df1, df2) = prepareForAsOfJoin() + val result = df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward") + val plan = result.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { + case _: SortMergeAsOfJoinExec => true + }.nonEmpty, s"Expected SortMergeAsOfJoinExec in plan:\n$plan") + } + + test("backward join - simple") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 3, "x", 3), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - usingColumns") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - left outer") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"), + joinType = "leftouter", tolerance = null, + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", null, null, null), + Row(5, "y", "b", null, null, null), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("forward join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "forward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", null, null, null) + 
).filter(_.get(3) != null) // inner join: no match for 10 + ) + } + + test("nearest join") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = true, direction = "nearest"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(5, "y", "b", 6, "y", 6), + Row(10, "z", "c", 7, "z", 7) + ) + ) + } + + test("backward join - tolerance = 1") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", + tolerance = functions.lit(1), + allowExactMatches = true, direction = "backward"), + Seq( + Row(1, "x", "a", 1, "v", 1), + Row(10, "z", "c", null, null, null) + ).filter(_.get(3) != null) + ) + } + + test("backward join - allowExactMatches = false") { + val (df1, df2) = prepareForAsOfJoin() + checkAnswer( + df1.joinAsOf( + df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty, + joinType = "inner", tolerance = null, + allowExactMatches = false, direction = "backward"), + Seq( + // left.a=1: no right row with a < 1 → no match + // left.a=5: right.a=3 (3 < 5) → match + Row(5, "y", "b", 3, "x", 3), + // left.a=10: right.a=7 (7 < 10) → match Review Comment: Non-ASCII `→`; use `->`. ```suggestion // left.a=10: right.a=7 (7 < 10) -> match ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
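Illustration for the NULL `EqualTo` equi-key comment on `SortMergeAsOfJoinScanner` above: a minimal, Spark-free sketch of why a comparator that returns 0 for two NULLs groups null-keyed rows together, and how a null-key guard restores `EqualTo` semantics. Everything here (`KeyedRow`, `NullKeyGuardSketch`, `join`, `skipNullKeys`) is hypothetical and invented for the example; `None` stands in for SQL NULL, and this is not the PR's code or the actual `SortMergeJoinScanner` logic.

```scala
// Hypothetical model row: (equi-join key, as-of key, payload). None models SQL NULL.
final case class KeyedRow(equiKey: Option[Int], ts: Int, value: String)

object NullKeyGuardSketch {
  // Stand-in for an ascending key ordering that reports compare == 0 for NULL vs NULL.
  private val nullsEqualOrdering: Ordering[Option[Int]] = Ordering.Option[Int]

  // Backward as-of match inside one key group: the latest right row with ts <= left.ts.
  private def backwardMatch(left: KeyedRow, group: Seq[KeyedRow]): Option[KeyedRow] =
    group.filter(_.ts <= left.ts).sortBy(_.ts).lastOption

  def join(
      left: Seq[KeyedRow],
      right: Seq[KeyedRow],
      leftOuter: Boolean,
      skipNullKeys: Boolean): Seq[(KeyedRow, Option[KeyedRow])] = {
    // Guard, part 1: null-keyed right rows are never buffered as candidates.
    val usableRight = if (skipNullKeys) right.filter(_.equiKey.isDefined) else right
    left.flatMap { l =>
      // Guard, part 2: a null-keyed left row never finds a group.
      val group =
        if (skipNullKeys && l.equiKey.isEmpty) Seq.empty[KeyedRow]
        else usableRight.filter(r => nullsEqualOrdering.compare(l.equiKey, r.equiKey) == 0)
      backwardMatch(l, group) match {
        case m @ Some(_)       => Seq((l, m))    // matched row
        case None if leftOuter => Seq((l, None)) // LeftOuter: emit with null right side
        case None              => Seq.empty      // Inner: drop the left row
      }
    }
  }
}
```

With one null-keyed row on each side, the unguarded variant pairs them, whereas the guarded variant drops the left row for an inner join and emits it with an empty right side for a left outer join. A regression test with NULL `EqualTo` keys would presumably check for that difference.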
