zhidongqu-db commented on code in PR #55629:
URL: https://github.com/apache/spark/pull/55629#discussion_r3174733028
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -657,6 +657,29 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
messageParameters = Map.empty)
}
+ // Reject streaming inputs early. The optimizer rewrite introduces
+ // `MonotonicallyIncreasingID()`, which is per-batch only and would silently produce
+ // incorrect results across micro-batches; failing at analysis time is clearer than
+ // letting the streaming check fire on an incidental MID node.
+ case j: NearestByJoin if j.isStreaming =>
Review Comment:
On the streaming guard: the current comment frames the issue as "MID is
per-batch only", but MID itself is fine within a batch. The real blocker is
that the rewrite uses a global Aggregate keyed by `__qid`, which Spark turns
into a stateful streaming aggregation. MID values restart in every
micro-batch, so state entries from old batches get merged with new rows that
reuse the same `__qid`, producing wrong top-K results.

MID is just an implementation detail (we only need a per-row group key), so
streaming support doesn't have to wait on a streaming-aware MID.
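
To make the failure mode concrete, here is a small plain-Scala simulation (hypothetical, no Spark involved). It assumes the documented layout of `monotonically_increasing_id()`, with the partition ID in the upper 31 bits and the record number within the partition in the lower 33 bits, and it models the stateful aggregation as a map from `__qid` to every left row merged into that group. `MidAcrossBatches`, `tag`, and `runBatches` are illustrative names only.

```scala
// Hypothetical simulation of why a __qid-keyed streaming aggregate goes wrong.
// Assumes MID's documented layout: (partition id << 33) + record index in that partition.
object MidAcrossBatches {
  def mid(partitionId: Int, recordIndex: Long): Long =
    (partitionId.toLong << 33) + recordIndex

  // Tag every left row of one micro-batch; `partitions` lists the rows of each input partition.
  def tag[A](partitions: Seq[Seq[A]]): Seq[(Long, A)] =
    partitions.zipWithIndex.flatMap { case (rows, pid) =>
      rows.zipWithIndex.map { case (row, i) => (mid(pid, i.toLong), row) }
    }

  // Stand-in for streaming aggregation state: each __qid accumulates every row seen under it
  // across all micro-batches, the way a stateful aggregate keeps merging into one key.
  def runBatches[A](batches: Seq[Seq[Seq[A]]]): Map[Long, Seq[A]] =
    batches.foldLeft(Map.empty[Long, Seq[A]]) { (state, batch) =>
      tag(batch).foldLeft(state) { case (acc, (qid, row)) =>
        acc.updated(qid, acc.getOrElse(qid, Seq.empty[A]) :+ row)
      }
    }
}
```

For example, `runBatches(Seq(Seq(Seq("q1"), Seq("q2")), Seq(Seq("q3"))))` puts both `"q1"` and `"q3"` under key `0L`: the first row of partition 0 in the second batch reuses the `__qid` of the first row of partition 0 in the first batch, so two unrelated left rows are merged into one top-K group.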

A few directions for the follow-up:

1. Group by `struct(left.*)` instead of MID. This is a pure Catalyst change:
every distinct left row is its own group. We'd need to handle duplicate left
rows (carry a count, expand at the end) and bail out on map-typed left
columns, since Spark can't group by a map. Lowest-risk path.
2. A dedicated physical operator that does per-row top-K against a
broadcast/streaming right side, with no cross join + aggregate. This is also
the operator the SPIP calls out as future work for performance, so it solves
two problems at once.
3. A batch-scoped aggregate (include `batch_id` in the key, or a
non-incremental aggregate variant). Doable, but it tangles us up with
streaming state/watermark semantics; not worth it IMO.
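
Direction 1 can be prototyped with plain collections before touching Catalyst. The sketch below is hypothetical (`GroupByLeftRow` and `nearestByDistance` are illustrative names): it keys groups by the left row value itself, collapses duplicates into `(row, count)`, and re-expands at the end so each duplicate still gets its own K matches. A full sort stands in for the bounded top-K of `MaxMinByK`, and ties keep input order, which the real aggregate does not necessarily promise. Scala can use any value as a map key, so the map-typed-column bail-out has no counterpart here.

```scala
// Hypothetical prototype of "group by struct(left.*)" with duplicate counting.
object GroupByLeftRow {
  def nearestByDistance[L, R, D: Ordering](left: Seq[L], right: Seq[R], k: Int)(
      dist: (L, R) => D): Seq[(L, R)] = {
    // One group per distinct left row; remember how many times each row occurred.
    val counts: Seq[(L, Int)] =
      left.groupBy(identity).toSeq.map { case (row, dups) => (row, dups.size) }
    counts.flatMap { case (row, n) =>
      // Top-k right rows for this left value (stable sort: ties keep input order).
      val topK = right.sortBy(r => dist(row, r)).take(k)
      // Expand back out so every duplicate left row gets its own k matches.
      Seq.fill(n)(topK.map(r => (row, r))).flatten
    }
  }
}
```

Group order follows `groupBy`, which is unspecified, so callers comparing results should sort first.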

Happy to leave the guard in this PR; I'm just suggesting we update the comment
to reflect the actual reason so future-us isn't misled.
##########
sql/core/src/test/resources/sql-tests/inputs/join-nearest-by.sql:
##########
@@ -0,0 +1,57 @@
+-- Test cases for NEAREST BY top-K ranking join.
+
Review Comment:
The new DeduplicateRelations branch for NearestByJoin (same pattern as
Join/LateralJoin/AsOfJoin) isn't covered by a test. A self-join case like:
```
SELECT * FROM t a JOIN t b APPROX NEAREST 1 BY DISTANCE abs(a.x - b.x)
```
would exercise the new branch and lock in the behavior. Without it, a future
refactor of DeduplicateRelations could silently break self-NEAREST-BY without
any test failing. Could we add one to join-nearest-by.sql (or as a plan test in
RewriteNearestByJoinSuite)? A PlanParserSuite test alone only covers parsing,
so it wouldn't reach DeduplicateRelations, which runs in the analyzer.
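
For the golden output, a brute-force reference in plain Scala is a quick way to predict what the suggested query should return. This is a hypothetical sketch: it models only the query semantics (it does not exercise `DeduplicateRelations` itself), the `id` column is invented for readability, and ties are broken by input order. With distinct `x` values and K = 1, every row's nearest neighbor is itself at distance 0, so the expected output pairs each row with itself.

```scala
// Hypothetical brute-force reference for:
//   SELECT * FROM t a JOIN t b APPROX NEAREST k BY DISTANCE abs(a.x - b.x)
object SelfNearestReference {
  final case class Row(id: Int, x: Int)

  def selfNearest(t: Seq[Row], k: Int): Seq[(Row, Row)] =
    t.flatMap { a =>
      // Rank every row of the second copy of t by distance to `a` (stable sort on ties).
      t.sortBy(b => math.abs(a.x - b.x)).take(k).map(b => (a, b))
    }
}
```

With `t = Seq(Row(1, 10), Row(2, 20), Row(3, 35))` and K = 2, row 2's second match is row 1 (distance 10 beats 15), which makes a slightly stronger golden case than K = 1.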
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala:
##########
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+/**
+ * Replaces a logical [[NearestByJoin]] operator with a `Generate(Inline(...))` over an
+ * `Aggregate` that tags each left row with a unique id, cross-joins with the right side, and
+ * groups by the unique id to compute the top-K matches via `MAX_BY`/`MIN_BY` (K-overload).
+ *
+ * Input Pseudo-Query:
+ * {{{
+ * SELECT * FROM left [INNER | LEFT OUTER] JOIN right
+ * {APPROX | EXACT} NEAREST k BY {DISTANCE | SIMILARITY} expr
+ * }}}
+ *
+ * Rewritten Plan (SIMILARITY, INNER join type):
+ * {{{
+ * Generate inline(_matches), [N], outer=false, [right.col1, right.col2, ...]
+ * +- Aggregate [__qid],
+ *    [first(left.col0) AS left.col0, ..., first(left.colN-1) AS left.colN-1,
+ *     max_by(struct(right.*), expr, k) AS _matches]
+ *    +- Join Inner
+ *       :- Project [left.*, monotonically_increasing_id() AS __qid]
+ *       :  +- left
+ *       +- right
+ * }}}
+ *
+ * For `DISTANCE`, `MIN_BY` is used instead of `MAX_BY`. For `LEFT OUTER`, the `Generate` is
+ * constructed with `outer = true` so left rows with no matches (empty/null `_matches`) are
+ * preserved with `NULL` right-side columns.
+ *
+ * In this initial implementation both `APPROX` and `EXACT` take the same brute-force rewrite
Review Comment:
This is intentional: we will follow up with QO rewrite logic specifically for
APPROX.
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala:
##########
@@ -0,0 +1,115 @@
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, MonotonicallyIncreasingID}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, MaxMinByK}
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project}
+
+class RewriteNearestByJoinSuite extends PlanTest {
Review Comment:
Could we expand the suite a bit? Current coverage is 3 cases
(similarity+inner+k=5, distance+inner+k=3, similarity+leftouter+k=1). Some gaps
worth filling in:
- distance + leftouter: the only joinType × direction combo not exercised, and
the LEFT OUTER path is where `outer = true` flips behavior.
- approx = false (EXACT): same rewrite shape as APPROX today, but a test would
lock in that fact so any future divergence is intentional.
- k = 1 and k = NearestByJoin.MaxNumResults: boundary values, especially
MaxNumResults now that it lives in a constant.
- Self-join (NearestByJoin(t, t, ...)): exercises the new DeduplicateRelations
branch added in this PR; without a test, that branch could regress silently.
- Cross joins disabled: wrap one case in
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") to lock in the
SYNTHETIC_JOIN_TAG behavior. Right now nothing tests that NEAREST BY still
works when cross joins are disabled.
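
The joinType × direction gap in the first bullet can be checked mechanically. The sketch below is hypothetical bookkeeping, not suite code: `JoinKind`, `Direction`, and `TestCase` are local stand-ins for Spark's join types and NEAREST BY directions, and it encodes only the three existing cases listed above.

```scala
// Hypothetical bookkeeping for the suite's coverage; these types are local stand-ins.
object NearestByCoverage {
  sealed trait JoinKind
  case object InnerJoin extends JoinKind
  case object LeftOuterJoin extends JoinKind

  sealed trait Direction
  case object Similarity extends Direction
  case object Distance extends Direction

  final case class TestCase(join: JoinKind, direction: Direction, k: Int)

  // The three cases the suite covers today.
  val existing: Seq[TestCase] = Seq(
    TestCase(InnerJoin, Similarity, 5),
    TestCase(InnerJoin, Distance, 3),
    TestCase(LeftOuterJoin, Similarity, 1))

  // Every joinType x direction pair that no existing case exercises.
  def missingCombos: Seq[(JoinKind, Direction)] = {
    val covered = existing.map(c => (c.join, c.direction)).toSet
    for {
      j <- Seq[JoinKind](InnerJoin, LeftOuterJoin)
      d <- Seq[Direction](Similarity, Distance)
      if !covered((j, d))
    } yield (j, d)
  }
}
```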
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala:
##########
@@ -0,0 +1,125 @@
+ * In this initial implementation both `APPROX` and `EXACT` take the same brute-force rewrite
+ * path. `APPROX` establishes the contract for future indexed-ANN strategies.
+ */
+object RewriteNearestByJoin extends Rule[LogicalPlan] {
Review Comment:
Yes, can we add a comment here? Something like:
```
Unlike RewriteAsOfJoin, which uses a correlated scalar subquery, this rule
materializes the cross product directly. A scalar subquery returns a single
value per left row, so it cannot carry K matches without an array-valued
subquery + Generate(Inline(...)) — which collapses back to the same cross
product after decorrelation. The aggregate-then-inline form makes the intended
shape explicit and avoids round-tripping through subquery decorrelation.
```
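
For readers comparing the two approaches, a collection-level model of the aggregate-then-inline shape may help. It is a hypothetical sketch, not Catalyst code: `NearestByRewriteModel` and its parameters are illustrative, `zipWithIndex` stands in for `monotonically_increasing_id()`, a full sort stands in for the bounded top-K of `MaxMinByK`, `Option` models SQL NULL, and ties keep input order.

```scala
// Hypothetical collection-level model of RewriteNearestByJoin's output shape.
object NearestByRewriteModel {
  sealed trait Direction
  case object Distance extends Direction   // MIN_BY: smallest ranking value first
  case object Similarity extends Direction // MAX_BY: largest ranking value first

  def rewrite[L, R, D: Ordering](
      left: Seq[L],
      right: Seq[R],
      k: Int,
      direction: Direction,
      leftOuter: Boolean)(ranking: (L, R) => D): Seq[(L, Option[R])] = {
    // 1. Tag each left row with a unique id (stand-in for monotonically_increasing_id()).
    val tagged = left.zipWithIndex.map { case (l, i) => (i.toLong, l) }

    // 2. LEFT OUTER join with no condition: a left row survives even if `right` is empty.
    val joined: Seq[(Long, L, Option[R])] = tagged.flatMap { case (qid, l) =>
      if (right.isEmpty) Seq((qid, l, Option.empty[R]))
      else right.map(r => (qid, l, Option(r)))
    }

    // 3. Aggregate by qid: first(left) plus the top-k right rows by the ranking expression.
    val order = direction match {
      case Distance   => Ordering[D]
      case Similarity => Ordering[D].reverse
    }
    val grouped = joined.groupBy(_._1).toSeq.sortBy(_._1).map { case (_, rows) =>
      val l = rows.head._2
      val matches = rows.flatMap(_._3).sortBy(r => ranking(l, r))(order).take(k)
      (l, matches)
    }

    // 4. inline(_matches): one row per match; outer = true keeps unmatched left rows.
    grouped.flatMap { case (l, matches) =>
      if (matches.nonEmpty) matches.map(r => (l, Option(r)))
      else if (leftOuter) Seq((l, Option.empty[R]))
      else Seq.empty[(L, Option[R])]
    }
  }
}
```

Step 2 shows why the join has to be LEFT OUTER even though it has no condition: with an inner join, a left row facing an empty `right` would vanish before step 4 could apply `outer`.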
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala:
##########
@@ -0,0 +1,125 @@
+object RewriteNearestByJoin extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput {
+ case j @ NearestByJoin(left, right, joinType, _, numResults, rankingExpression, direction) =>
+ // 1. Tag each left row with a unique id so that rows from the same left row can later be
+ // grouped together after the cross-join with `right`.
+ val qidAlias = Alias(MonotonicallyIncreasingID(), "__qid")()
+ val taggedLeft = Project(left.output :+ qidAlias, left)
+ val qidAttr = qidAlias.toAttribute
+
+ // 2. LEFT OUTER-join the tagged left with right (no join condition). LEFT OUTER
+ // (rather than INNER) preserves left rows even when `right` is empty, so that a
+ // `LEFT OUTER NEAREST BY` query still returns those rows with `NULL` right-side
+ // columns after the aggregate + inline below. When `right` is non-empty every left
+ // row already has right-row pairings, so LEFT OUTER and INNER are equivalent.
+ //
+ // Tag the join so `CheckCartesianProducts` skips it: the rewrite intentionally
+ // materializes a cross product bounded by the downstream `MaxMinByK` aggregate, so
+ // `spark.sql.crossJoin.enabled = false` should not reject user queries written as
+ // `NEAREST BY`.
+ val join = Join(taggedLeft, right, LeftOuter, None, JoinHint.NONE)
+ join.setTagValue(NearestByJoin.SYNTHETIC_JOIN_TAG, ())
Review Comment:
I'd lean against (a): moving the rule out of FinishAnalysis forfeits the
predicate pushdown / column pruning that runs over the rewritten plan, which
seems like a worse trade. (b) is the right long-term fix: detect the shape in
CheckCartesianProducts (e.g., an Aggregate over Join(LeftOuter, condition=None)
with a MaxMinByK aggregate expression) and skip on that, with no tag needed. We
should do (c) regardless: at minimum, a test with `crossJoinEnabled=false` plus
one with a filter pushed down toward the rewrite.
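
Option (b) could look roughly like the following shape check. It is a hypothetical sketch over a toy plan ADT: `Plan`, `Join`, `Aggregate`, `AggFn`, and `cartesianViolations` are local stand-ins, not Catalyst classes or signatures. A real version in `CheckCartesianProducts` would match Catalyst's own `Aggregate`/`Join` nodes and look for `MaxMinByK` inside the aggregate expressions. The detail worth preserving is that only the exempted join is skipped: its children are still checked, so a genuine cartesian product under either side is still reported.

```scala
// Hypothetical mini plan ADT; these are NOT Catalyst's classes.
object CartesianShapeCheck {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType

  sealed trait AggFn
  case object First extends AggFn
  case object MaxMinByK extends AggFn

  sealed trait Plan
  final case class Relation(name: String) extends Plan
  final case class Join(left: Plan, right: Plan, joinType: JoinType, condition: Option[String])
    extends Plan
  final case class Aggregate(groupBy: Seq[String], aggs: Seq[AggFn], child: Plan) extends Plan

  // Condition-less joins a cross-join check would reject. The join directly under an Aggregate
  // that uses MaxMinByK (the NEAREST BY rewrite shape) is exempt, but its children are checked.
  def cartesianViolations(plan: Plan): Seq[Join] = plan match {
    case Aggregate(_, aggs, Join(l, r, LeftOuter, None)) if aggs.contains(MaxMinByK) =>
      cartesianViolations(l) ++ cartesianViolations(r)
    case Aggregate(_, _, child)  => cartesianViolations(child)
    case j @ Join(l, r, _, None) => j +: (cartesianViolations(l) ++ cartesianViolations(r))
    case Join(l, r, _, Some(_))  => cartesianViolations(l) ++ cartesianViolations(r)
    case Relation(_)             => Seq.empty
  }
}
```

One trade-off versus the tag: a shape match exempts every plan of this form, including a hand-written one if the K-overload of `max_by`/`min_by` is reachable from user SQL, whereas the tag only exempts joins the rule itself created.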
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala:
##########
@@ -0,0 +1,125 @@
+
+ // 3. Aggregate grouped by `__qid`:
+ // - first(col) for every left column so it flows to the output.
+ // - max_by/min_by(struct(right.*), ranking, k) as `_matches`.
+ // The ranking expression references left and right columns directly; no outer
+ // reference is needed because both sides are present in the joined input.
+ val rightStruct = CreateStruct(right.output)
+ // reverse = true -> MIN_BY (smallest ranking value first, for DISTANCE)
+ // reverse = false -> MAX_BY (largest ranking value first, for SIMILARITY)
+ val reverse = direction match {
+ case NearestByDistance => true
+ case NearestBySimilarity => false
+ }
+ val topK = MaxMinByK(
+ rightStruct,
+ rankingExpression,
+ Literal(numResults),
+ reverse = reverse).toAggregateExpression()
+ val matchesAlias = Alias(topK, "__nearest_matches__")()
+
+ // Carry left columns through with `First`. Within a `__qid` group every row has the same
+ // left values (each group corresponds to one left row), so `First` is effectively a no-op.
+ // We use `First` rather than adding all left columns to the GROUP BY because grouping by
+ // `__qid` alone keeps the shuffle key small.
+ val firstLeftAggs = left.output.map { attr =>
+ Alias(
+ First(attr, ignoreNulls = false).toAggregateExpression(),
Review Comment:
Optimization idea: for INNER NEAREST BY specifically, we could drop the
per-column First aggregates entirely by packing struct(left.* ++ right.*) into
MaxMinByK directly. inline(_matches) would then produce all output columns in
one shot: no First, no separate left tracking, just one aggregate expression
and a smaller plan.

This doesn't extend to LEFT OUTER (the unmatched-row case needs left.* available
outside the matches array; otherwise Generate(outer = true) emits all NULLs),
so we'd want to branch the rewrite by joinType. The trade-off is K redundant
copies of left.* in the heap vs. a much simpler INNER plan, which is probably
worth it for the common case.
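
The trade-off can be seen in a small collection model. It is a hypothetical sketch, not the rewrite itself: `PackedMatchesModel` and its methods are illustrative names, tuples stand in for `struct(left.* ++ right.*)`, a sort stands in for `MaxMinByK`, and `Option` models SQL NULL. For INNER the packed array is sufficient on its own; for LEFT OUTER the unmatched row comes back as `(None, None)`, which is the all-NULL row described above.

```scala
// Hypothetical model of the proposed INNER-only packing; not Catalyst code.
object PackedMatchesModel {
  // INNER: pack whole (left, right) pairs into the top-k array, so inline(_matches) alone
  // produces complete output rows and no First(...) aggregates are needed.
  def innerPacked[L, R, D: Ordering](left: Seq[L], right: Seq[R], k: Int)(
      dist: (L, R) => D): Seq[(L, R)] =
    left.flatMap { l =>
      right.map(r => (l, r)).sortBy { case (a, b) => dist(a, b) }.take(k)
    }

  // LEFT OUTER with the same packing: an unmatched left row has an empty array, and
  // Generate(outer = true) can only emit one all-NULL row, so the left values are lost.
  def leftOuterPacked[L, R, D: Ordering](left: Seq[L], right: Seq[R], k: Int)(
      dist: (L, R) => D): Seq[(Option[L], Option[R])] =
    left.flatMap { l =>
      val matches = right.map(r => (l, r)).sortBy { case (a, b) => dist(a, b) }.take(k)
      if (matches.isEmpty) Seq((Option.empty[L], Option.empty[R]))
      else matches.map { case (a, b) => (Option(a), Option(b)) }
    }
}
```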
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala:
##########
@@ -0,0 +1,125 @@
+ val firstLeftAggs = left.output.map { attr =>
+ Alias(
+ First(attr, ignoreNulls = false).toAggregateExpression(),
+ attr.name)(exprId = attr.exprId, qualifier = attr.qualifier)
+ }
+ val aggregate = Aggregate(Seq(qidAttr), firstLeftAggs :+ matchesAlias, join)
+
+ // 4. Generate inline(_matches) expands the K-element array into K rows, exposing each
+ // struct field as a top-level column. `outer = true` for LEFT OUTER preserves the
+ // left row with NULL right columns when there are no matches.
+ val generatorOutput = right.output.map { a =>
+ AttributeReference(a.name, a.dataType, nullable = true)(qualifier = a.qualifier)
+ }
Review Comment:
+1
--
This is an automated message from the Apache Git Service.