yadavay-amzn commented on code in PR #56101:
URL: https://github.com/apache/spark/pull/56101#discussion_r3472001103


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2352,6 +2352,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val NEAREST_BY_BROADCAST_ENABLED =
+    buildConf("spark.sql.join.nearestBy.broadcast.enabled")
+      .internal()
+      .doc("When true, NearestByJoin uses a streaming heap operator instead of the " +
+        "cross-product + aggregate rewrite.")
+      .version("5.0.0")

Review Comment:
   Done - set to 4.3.0.
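The flag's doc string describes a choice between two plans. The `BroadcastNearestByJoinExec` Scaladoc below adds that the heap operator only applies when the right side fits within `spark.sql.autoBroadcastJoinThreshold`. The following is a minimal sketch of how those two conditions might combine, assuming a simplified planner. `NearestByStrategySketch`, `choose`, and the strategy names are hypothetical. The real planner rule in the PR may check more than this, for example the join type.

```scala
object NearestByStrategySketch {
  sealed trait Strategy
  // Hypothetical stand-in for planning BroadcastNearestByJoinExec (bounded heap per left row).
  case object StreamingHeap extends Strategy
  // Hypothetical stand-in for the existing cross-product + aggregate rewrite.
  case object CrossProductAggregate extends Strategy

  /**
   * Hypothetical planner decision (not the PR's code).
   *
   * @param broadcastEnabled       value of spark.sql.join.nearestBy.broadcast.enabled
   * @param rightSizeInBytes       estimated right-side size (simplified to Long here)
   * @param autoBroadcastThreshold value of spark.sql.autoBroadcastJoinThreshold; -1 disables broadcast
   */
  def choose(broadcastEnabled: Boolean, rightSizeInBytes: Long, autoBroadcastThreshold: Long): Strategy = {
    // Same shape as Spark's broadcast size check: a negative threshold can never be satisfied.
    val fitsInBroadcast = rightSizeInBytes >= 0 && rightSizeInBytes <= autoBroadcastThreshold
    if (broadcastEnabled && fitsInBroadcast) StreamingHeap else CrossProductAggregate
  }
}
```

With this shape, turning the flag off or setting the threshold to `-1` both fall back to the rewrite. Either way the query still plans, just without the heap operator.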



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNearestByJoinExec.scala:
##########
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.{Comparator, PriorityQueue => JPriorityQueue}
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, NearestByDirection, NearestByDistance}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.execution.{ExplainUtils, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Heap entry storing an index into the broadcast array alongside its ranking value.
+ * Using a case class with primitive `Int` field avoids boxing that `(Int, Any)` tuples incur.
+ */
+private[joins] case class HeapEntry(index: Int, rankingValue: Any)
+
+/**
+ * Physical operator for NearestByJoin that avoids materializing the full cross product.
+ * For each left row, iterates all broadcast right rows maintaining a bounded priority
+ * queue of size k, then emits the top-k matches directly.
+ *
+ * The right side is fully broadcast to all partitions. This operator only fires when
+ * the right side fits within [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]]. For right tables
+ * exceeding this threshold, the existing cross-product + aggregate rewrite is used as
+ * fallback. Tie-breaking among equal ranking values is non-deterministic (matches the
+ * existing rewrite behavior).
+ */
+case class BroadcastNearestByJoinExec(
+    left: SparkPlan,
+    right: SparkPlan,
+    joinType: JoinType,
+    numResults: Int,
+    rankingExpression: Expression,
+    direction: NearestByDirection) extends BaseJoinExec {
+
+  override def condition: Option[Expression] = None
+  override def leftKeys: Seq[Expression] = Seq.empty
+  override def rightKeys: Seq[Expression] = Seq.empty
+
+  override def simpleStringWithNodeId(): String = {
+    val opId = ExplainUtils.getOpId(this)
+    s"$nodeName $joinType k=$numResults $direction ($opId)".trim
+  }
+
+  override def output: Seq[Attribute] = joinType match {
+    case _: InnerLike | LeftOuter =>
+      left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+    case other =>
+      throw SparkException.internalError(
+        s"$nodeName does not support join type: $other")
+  }
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "streamedRows" -> SQLMetrics.createMetric(sparkContext, "number of left rows processed"))
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = Nil
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val broadcastedRight = right.executeBroadcast[Array[InternalRow]]()
+    val numOutput = longMetric("numOutputRows")
+    val streamedRowsMetric = longMetric("streamedRows")
+    val localJoinType = joinType
+    val k = numResults
+    val isDistance = direction == NearestByDistance
+    val leftOutput = left.output
+    val rightOutput = right.output
+    val rankExpr = rankingExpression
+    val allOutput = output
+    val ordering = TypeUtils.getInterpretedOrdering(rankExpr.dataType)
+
+    left.execute().mapPartitionsInternal { leftIter =>
+      val rightRows = broadcastedRight.value
+      if (rightRows.isEmpty && localJoinType != LeftOuter) {
+        Iterator.empty
+      } else {
+        val joinedRow = new JoinedRow
+        val rankingProj = UnsafeProjection.create(
+          Seq(rankExpr), leftOutput ++ rightOutput)
+        val resultProj = UnsafeProjection.create(allOutput, allOutput)
+
+        // Hoist heap outside flatMap to reduce GC pressure
+        val heap = if (isDistance) {
+          new JPriorityQueue[HeapEntry](k + 1,
+            new Comparator[HeapEntry] {
+              override def compare(a: HeapEntry, b: HeapEntry): Int =
+                ordering.compare(b.rankingValue, a.rankingValue)
+            })
+        } else {
+          new JPriorityQueue[HeapEntry](k + 1,
+            new Comparator[HeapEntry] {
+              override def compare(a: HeapEntry, b: HeapEntry): Int =
+                ordering.compare(a.rankingValue, b.rankingValue)
+            })
+        }
+
+        leftIter.flatMap { leftRow =>
+          streamedRowsMetric += 1
+          heap.clear()
+
+          var i = 0
+          while (i < rightRows.length) {
+            val rightRow = rightRows(i)
+            joinedRow(leftRow, rightRow)
+            val rankingRow = rankingProj(joinedRow).copy()

Review Comment:
   Done. The copy now happens only for variable-length ranking types: gated on 
`!UnsafeRow.isFixedLength(rankExpr.dataType)` (precomputed once), so 
fixed-width types (numeric, Date/Timestamp, small Decimal) extract an immutable 
value with no per-pair copy, while variable-length types 
(String/Binary/struct/array/map, and Decimal with precision > 18) copy the 
single-column row to detach it from the reused buffer. The String and 
Decimal(19) retain tests cover the variable-length path.
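The reasoning in the comment above can be modeled without Spark. A projection that writes into one reused buffer hands back a view. A retained variable-length value can therefore be overwritten by the next pair unless it is copied, while a fixed-width value is extracted as an immutable primitive. The sketch below uses simplified, hypothetical stand-ins (`RankType`, `ReusedProjection`, and a reduced `isFixedLength`). None of these are Spark classes; Spark's real check is `UnsafeRow.isFixedLength(dataType)`.

```scala
import scala.collection.mutable

object CopyGateSketch {
  // Hypothetical, simplified ranking types; not Spark's DataType hierarchy.
  sealed trait RankType
  case object IntRank extends RankType
  case object DoubleRank extends RankType
  case object TimestampRank extends RankType
  final case class DecimalRank(precision: Int) extends RankType
  case object StringRank extends RankType

  // Decimals with at most 18 digits fit in a Long, so they are stored inline (fixed width).
  val MaxLongDigits = 18

  // Reduced analogue of the isFixedLength gate described in the comment.
  def isFixedLength(t: RankType): Boolean = t match {
    case IntRank | DoubleRank | TimestampRank => true
    case DecimalRank(p)                       => p <= MaxLongDigits
    case StringRank                           => false
  }

  // Stand-in for a projection that writes every result into the same reused buffer.
  final class ReusedProjection {
    private val buffer = new mutable.StringBuilder
    def apply(value: String): mutable.StringBuilder = {
      buffer.clear()
      buffer.append(value)
    }
  }

  // Retains one projected string per input, detaching it from the buffer only when copy is true.
  def retain(values: Seq[String], copy: Boolean): List[String] = {
    val proj = new ReusedProjection
    val kept = mutable.ArrayBuffer.empty[AnyRef]
    values.foreach { v =>
      val out = proj(v)
      // Without a copy, every retained entry aliases the shared buffer.
      kept += (if (copy) out.toString else out)
    }
    kept.map(_.toString).toList
  }
}
```

Here `retain(Seq("apple", "kiwi"), copy = false)` yields `List("kiwi", "kiwi")` because both entries alias the shared buffer. With `copy = true` it yields `List("apple", "kiwi")`. Computing `!isFixedLength(t)` once per operator, as the comment describes, reduces the per-pair cost for fixed-width types to a single branch.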

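The `BroadcastNearestByJoinExec` Scaladoc describes keeping a bounded priority queue of size k per left row. Below is a minimal sketch of that technique on plain `Double` ranking values. `BoundedTopKSketch`, `Entry`, and `topK` are hypothetical. The PR itself compares `Any` values through `TypeUtils.getInterpretedOrdering`. As in the diff's `isDistance` branch, distance ranking reverses the comparator. This puts the worst (largest) retained value at the head of the heap, so it can be evicted with one `poll()`. Offering every entry and then polling once the size exceeds k is an assumption here, though it is consistent with the `k + 1` initial capacity in the diff.

```scala
import java.util.{Comparator, PriorityQueue => JPriorityQueue}
import scala.collection.mutable.ArrayBuffer

object BoundedTopKSketch {
  // Hypothetical simplified entry: index into the broadcast array plus its ranking value.
  final case class Entry(index: Int, rankingValue: Double)

  /**
   * Indices of the k best right rows for one left row, best first.
   * smallestFirst = true models distance ranking (keep the k smallest values);
   * false keeps the k largest values.
   */
  def topK(ranking: Array[Double], k: Int, smallestFirst: Boolean): List[Int] = {
    require(k >= 0, "k must be non-negative")
    // The heap head is always the worst retained entry, so eviction is a single poll().
    val worstFirst: Comparator[Entry] =
      if (smallestFirst) {
        new Comparator[Entry] { // max-heap: largest value at the head
          override def compare(a: Entry, b: Entry): Int =
            java.lang.Double.compare(b.rankingValue, a.rankingValue)
        }
      } else {
        new Comparator[Entry] { // min-heap: smallest value at the head
          override def compare(a: Entry, b: Entry): Int =
            java.lang.Double.compare(a.rankingValue, b.rankingValue)
        }
      }
    val heap = new JPriorityQueue[Entry](k + 1, worstFirst)
    var i = 0
    while (i < ranking.length) {
      heap.offer(Entry(i, ranking(i)))
      if (heap.size() > k) heap.poll() // evict the worst entry so at most k remain
      i += 1
    }
    // poll() drains worst-first; reverse so the best match comes first.
    val drained = ArrayBuffer.empty[Int]
    while (!heap.isEmpty) drained += heap.poll().index
    drained.reverse.toList
  }
}
```

Each left row costs O(n log k) for n broadcast rows, and memory stays at k + 1 entries regardless of n. Ties are resolved by heap internals, so, like the operator, the sketch makes no promise about which of several equal values survives.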


-- 
This is an automated message from the Apache Git Service.