cloud-fan commented on code in PR #56101:
URL: https://github.com/apache/spark/pull/56101#discussion_r3462905349


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNearestByJoinExec.scala:
##########
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.{Comparator, PriorityQueue => JPriorityQueue}
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, 
NearestByDirection, NearestByDistance}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Heap entry storing an index into the broadcast array alongside its ranking 
value.
+ * Using a case class with primitive `Int` field avoids boxing that `(Int, 
Any)` tuples incur.
+ */
+private[joins] case class HeapEntry(index: Int, rankingValue: Any)
+
+/**
+ * Physical operator for NearestByJoin that avoids materializing the full 
cross product.
+ * For each left row, iterates all broadcast right rows maintaining a bounded 
priority
+ * queue of size k, then emits the top-k matches directly.
+ *
+ * The right side is fully broadcast to all partitions. This operator only 
fires when
+ * the right side fits within [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. For 
right tables
+ * exceeding this threshold, the existing cross-product + aggregate rewrite is 
used as
+ * fallback. Tie-breaking among equal ranking values is non-deterministic 
(matches the
+ * existing rewrite behavior).
+ */
+case class BroadcastNearestByJoinExec(
+    left: SparkPlan,
+    right: SparkPlan,
+    joinType: JoinType,
+    numResults: Int,
+    rankingExpression: Expression,
+    direction: NearestByDirection) extends BaseJoinExec {
+
+  override def condition: Option[Expression] = None
+  override def leftKeys: Seq[Expression] = Seq.empty
+  override def rightKeys: Seq[Expression] = Seq.empty
+
+  override def simpleStringWithNodeId(): String = {
+    val opId = ExplainUtils.getOpId(this)
+    s"$nodeName $joinType k=$numResults $direction ($opId)".trim
+  }
+
+  override def output: Seq[Attribute] = joinType match {
+    case _: InnerLike | LeftOuter =>
+      left.output.map(_.withNullability(true)) ++ 
right.output.map(_.withNullability(true))
+    case other =>
+      throw SparkException.internalError(
+        s"$nodeName does not support join type: $other")
+  }
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"),
+    "streamedRows" -> SQLMetrics.createMetric(sparkContext, "number of left 
rows processed"))
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: 
Nil
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = Nil
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val broadcastedRight = right.executeBroadcast[Array[InternalRow]]()
+    val numOutput = longMetric("numOutputRows")
+    val streamedRowsMetric = longMetric("streamedRows")
+    val localJoinType = joinType
+    val k = numResults
+    val isDistance = direction == NearestByDistance
+    val leftOutput = left.output
+    val rightOutput = right.output
+    val rankExpr = rankingExpression
+    val allOutput = output
+    val ordering = TypeUtils.getInterpretedOrdering(rankExpr.dataType)
+
+    left.execute().mapPartitionsInternal { leftIter =>
+      val rightRows = broadcastedRight.value
+      if (rightRows.isEmpty && localJoinType != LeftOuter) {
+        Iterator.empty
+      } else {
+        val joinedRow = new JoinedRow
+        val rankingProj = UnsafeProjection.create(
+          Seq(rankExpr), leftOutput ++ rightOutput)
+        val resultProj = UnsafeProjection.create(allOutput, allOutput)
+
+        // Hoist heap outside flatMap to reduce GC pressure
+        val heap = if (isDistance) {
+          new JPriorityQueue[HeapEntry](k + 1,
+            new Comparator[HeapEntry] {
+              override def compare(a: HeapEntry, b: HeapEntry): Int =
+                ordering.compare(b.rankingValue, a.rankingValue)
+            })
+        } else {
+          new JPriorityQueue[HeapEntry](k + 1,
+            new Comparator[HeapEntry] {
+              override def compare(a: HeapEntry, b: HeapEntry): Int =
+                ordering.compare(a.rankingValue, b.rankingValue)
+            })
+        }
+
+        leftIter.flatMap { leftRow =>
+          streamedRowsMetric += 1
+          heap.clear()
+
+          var i = 0
+          while (i < rightRows.length) {
+            val rightRow = rightRows(i)
+            joinedRow(leftRow, rightRow)
+            val rankingRow = rankingProj(joinedRow).copy()

Review Comment:
   `rankingProj(joinedRow).copy()` copies the full ranking `UnsafeRow` for 
every (left, right) pair — N×M times — before the null check and even for rows 
the heap immediately evicts. The `.copy()` is necessary for variable-length 
ranking types (String/Binary/struct, exercised by the StringType test), but for 
the common primitive case (Double/Int/Long distance) it allocates a `byte[]` 
per pair, which works against the GC-pressure win this operator exists to 
deliver.
   
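   Consider extracting the ranking value from the projected row first and 
copying only when `rankExpr.dataType` is variable-length. Alternatively, defer 
the copy until the entry is actually inserted into the heap: compare the 
uncopied value against the heap's head, and clone just the retained value 
(e.g. via `InternalRow.copyValue`) rather than the whole row. Non-blocking.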



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2352,6 +2352,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val NEAREST_BY_BROADCAST_ENABLED =
+    buildConf("spark.sql.join.nearestBy.broadcast.enabled")
+      .internal()
+      .doc("When true, NearestByJoin uses a streaming heap operator instead of 
the " +
+        "cross-product + aggregate rewrite.")
+      .version("5.0.0")

Review Comment:
   This config gates an additive, opt-in optimization, and `NearestByJoin` 
already exists on `branch-4.x` (4.3.0-SNAPSHOT), so the feature is 
backportable. By Spark convention, `.version(...)` records the first release 
that will ship the config. Here that is the next *open* feature release (4.3.0), 
not master's SNAPSHOT version (5.0.0).
   ```suggestion
         .version("4.3.0")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to